1use arrow::compute::can_cast_types;
21use datafusion_expr::binary::BinaryTypeCoercer;
22use itertools::{Itertools as _, izip};
23use std::sync::{Arc, LazyLock};
24
25use crate::analyzer::AnalyzerRule;
26use crate::utils::NamePreserver;
27
28use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
29use arrow::temporal_conversions::SECONDS_IN_DAY;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
32use datafusion_common::{
33 Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
35 plan_err,
36};
37use datafusion_expr::expr::{
38 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists,
39 HigherOrderFunction, InList, InSubquery, Like, ScalarFunction, SetComparison, Sort,
40 WindowFunction,
41};
42use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
43use datafusion_expr::expr_schema::cast_subquery;
44use datafusion_expr::logical_plan::Subquery;
45use datafusion_expr::type_coercion::binary::{
46 comparison_coercion, like_coercion, type_union_coercion,
47};
48use datafusion_expr::type_coercion::functions::{
49 UDFCoercionExt, fields_with_udf, value_fields_with_higher_order_udf_and_lambdas,
50};
51use datafusion_expr::type_coercion::other::{
52 get_coerce_type_for_case_expression, get_coerce_type_for_case_when,
53 get_coerce_type_for_list,
54};
55use datafusion_expr::type_coercion::{
56 is_datetime, is_interval, is_signed_numeric, is_timestamp,
57};
58use datafusion_expr::utils::merge_schema;
59use datafusion_expr::{
60 Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union,
61 ValueOrLambda, WindowFrame, WindowFrameBound, WindowFrameUnits, is_false,
62 is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not,
63};
64
/// Analyzer rule registered under the name `type_coercion`.
///
/// The rule itself is stateless; all work happens in its `AnalyzerRule`
/// implementation.
#[derive(Debug, Default)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new, stateless `TypeCoercion` rule.
    pub fn new() -> Self {
        Self::default()
    }
}
75
76fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
78 if !config.optimizer.expand_views_at_output {
79 return Ok(plan);
80 }
81
82 let outer_refs = plan.expressions();
83 if outer_refs.is_empty() {
84 return Ok(plan);
85 }
86
87 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
88 coerce_plan_expr_for_schema(plan, &dfschema?)
89 } else {
90 Ok(plan)
91 }
92}
93
94impl AnalyzerRule for TypeCoercion {
95 fn name(&self) -> &str {
96 "type_coercion"
97 }
98
99 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
100 static EMPTY_SCHEMA: LazyLock<DFSchema> = LazyLock::new(DFSchema::empty);
101
102 let transformed_plan = plan
104 .transform_up_with_subqueries(|plan| analyze_internal(&EMPTY_SCHEMA, plan))?
105 .data;
106
107 coerce_output(transformed_plan, config)
109 }
110}
111
112fn analyze_internal(
116 external_schema: &DFSchema,
117 plan: LogicalPlan,
118) -> Result<Transformed<LogicalPlan>> {
119 let mut schema = merge_schema(&plan.inputs());
122
123 if let LogicalPlan::TableScan(ts) = &plan {
124 let source_schema = DFSchema::try_from_qualified_schema(
125 ts.table_name.clone(),
126 &ts.source.schema(),
127 )?;
128 schema.merge(&source_schema);
129 }
130
131 schema.merge(external_schema);
135
136 let plan = if let LogicalPlan::Filter(mut filter) = plan {
138 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
139 LogicalPlan::Filter(filter)
140 } else {
141 plan
142 };
143
144 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
145
146 let name_preserver = NamePreserver::new(&plan);
147 plan.map_expressions(|expr| {
149 let original_name = name_preserver.save(&expr);
150 expr.rewrite(&mut expr_rewrite)
151 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
152 })?
153 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
155 .map_data(|plan| plan.recompute_schema())
157}
158
159pub struct TypeCoercionRewriter<'a> {
161 pub(crate) schema: &'a DFSchema,
162}
163
164impl<'a> TypeCoercionRewriter<'a> {
165 pub fn new(schema: &'a DFSchema) -> Self {
168 Self { schema }
169 }
170
171 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
176 match plan {
177 LogicalPlan::Join(join) => self.coerce_join(join),
178 LogicalPlan::Union(union) => Self::coerce_union(union),
179 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
180 _ => Ok(plan),
181 }
182 }
183
184 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
193 join.on = join
194 .on
195 .into_iter()
196 .map(|(lhs, rhs)| {
197 let left_schema = join.left.schema();
200 let right_schema = join.right.schema();
201 let (lhs, rhs) = self.coerce_binary_op(
202 lhs,
203 left_schema,
204 Operator::Eq,
205 rhs,
206 right_schema,
207 )?;
208 Ok((lhs, rhs))
209 })
210 .collect::<Result<Vec<_>>>()?;
211
212 join.filter = join
214 .filter
215 .map(|expr| self.coerce_join_filter(expr))
216 .transpose()?;
217
218 Ok(LogicalPlan::Join(join))
219 }
220
221 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
224 let union_schema = Arc::new(coerce_union_schema_with_schema(
225 &union_plan.inputs,
226 &union_plan.schema,
227 )?);
228 let new_inputs = union_plan
229 .inputs
230 .into_iter()
231 .map(|p| {
232 let plan =
233 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
234 match plan {
235 LogicalPlan::Projection(Projection { expr, input, .. }) => {
236 Ok(Arc::new(project_with_column_index(
237 expr,
238 input,
239 Arc::clone(&union_schema),
240 )?))
241 }
242 other_plan => Ok(Arc::new(other_plan)),
243 }
244 })
245 .collect::<Result<Vec<_>>>()?;
246 Ok(LogicalPlan::Union(Union {
247 inputs: new_inputs,
248 schema: union_schema,
249 }))
250 }
251
252 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
254 fn coerce_limit_expr(
255 expr: Expr,
256 schema: &DFSchema,
257 expr_name: &str,
258 ) -> Result<Expr> {
259 let dt = expr.get_type(schema)?;
260 if dt.is_integer() || dt.is_null() {
261 expr.cast_to(&DataType::Int64, schema)
262 } else {
263 plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
264 }
265 }
266
267 let empty_schema = DFSchema::empty();
268 let new_fetch = limit
269 .fetch
270 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
271 .transpose()?;
272 let new_skip = limit
273 .skip
274 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
275 .transpose()?;
276 Ok(LogicalPlan::Limit(Limit {
277 input: limit.input,
278 fetch: new_fetch.map(Box::new),
279 skip: new_skip.map(Box::new),
280 }))
281 }
282
283 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
284 let expr_type = expr.get_type(self.schema)?;
285 match expr_type {
286 DataType::Boolean => Ok(expr),
287 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
288 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
289 }
290 }
291
292 fn coerce_binary_op(
293 &self,
294 left: Expr,
295 left_schema: &DFSchema,
296 op: Operator,
297 right: Expr,
298 right_schema: &DFSchema,
299 ) -> Result<(Expr, Expr)> {
300 let left_data_type = left.get_type(left_schema)?;
301 let right_data_type = right.get_type(right_schema)?;
302 let (left_type, right_type) =
303 BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type)
304 .get_input_types()?;
305 let left_cast_ok = can_cast_types(&left_data_type, &left_type);
306 let right_cast_ok = can_cast_types(&right_data_type, &right_type);
307
308 let left_expr = if !left_cast_ok {
312 Self::coerce_date_time_math_op(
313 left,
314 &op,
315 &left_data_type,
316 &left_type,
317 &right_type,
318 )?
319 } else {
320 left.cast_to(&left_type, left_schema)?
321 };
322
323 let right_expr = if !right_cast_ok {
324 Self::coerce_date_time_math_op(
325 right,
326 &op,
327 &right_data_type,
328 &right_type,
329 &left_type,
330 )?
331 } else {
332 right.cast_to(&right_type, right_schema)?
333 };
334
335 Ok((left_expr, right_expr))
336 }
337
338 fn coerce_date_time_math_op(
339 expr: Expr,
340 op: &Operator,
341 left_current_type: &DataType,
342 left_target_type: &DataType,
343 right_target_type: &DataType,
344 ) -> Result<Expr, DataFusionError> {
345 use DataType::*;
346
347 fn cast(expr: Expr, target_type: DataType) -> Expr {
348 Expr::Cast(Cast::new(Box::new(expr), target_type))
349 }
350
351 fn time_to_nanos(
352 expr: Expr,
353 expr_type: &DataType,
354 ) -> Result<Expr, DataFusionError> {
355 let expr = match expr_type {
356 Time32(TimeUnit::Second) => {
357 cast(cast(expr, Int32), Int64)
358 * lit(ScalarValue::Int64(Some(1_000_000_000)))
359 }
360 Time32(TimeUnit::Millisecond) => {
361 cast(cast(expr, Int32), Int64)
362 * lit(ScalarValue::Int64(Some(1_000_000)))
363 }
364 Time64(TimeUnit::Microsecond) => {
365 cast(expr, Int64) * lit(ScalarValue::Int64(Some(1_000)))
366 }
367 Time64(TimeUnit::Nanosecond) => cast(expr, Int64),
368 t => return internal_err!("Unexpected time data type {t}"),
369 };
370
371 Ok(expr)
372 }
373
374 let e = match (
375 &op,
376 &left_current_type,
377 &left_target_type,
378 &right_target_type,
379 ) {
380 (
382 Operator::Plus | Operator::Minus,
383 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
384 Interval(IntervalUnit::MonthDayNano),
385 Date32 | Date64,
386 ) => {
387 let expr = match *left_current_type {
389 Int64 => expr,
390 _ => cast(expr, Int64),
391 };
392 let expr = expr * lit(ScalarValue::from(SECONDS_IN_DAY));
394 let expr = cast(expr, Duration(TimeUnit::Second));
396 cast(expr, Interval(IntervalUnit::MonthDayNano))
398 }
399 (
408 Operator::Plus | Operator::Minus,
409 Time32(_) | Time64(_),
410 Duration(TimeUnit::Nanosecond),
411 Timestamp(TimeUnit::Nanosecond, None),
412 ) => {
413 let expr = time_to_nanos(expr, left_current_type)?;
415 cast(expr, Duration(TimeUnit::Nanosecond))
417 }
418 (
424 Operator::Plus | Operator::Minus,
425 Time32(_) | Time64(_),
426 Interval(IntervalUnit::MonthDayNano),
427 Interval(IntervalUnit::MonthDayNano),
428 ) => {
429 let expr = time_to_nanos(expr, left_current_type)?;
431 let expr = cast(expr, Duration(TimeUnit::Nanosecond));
433 cast(expr, Interval(IntervalUnit::MonthDayNano))
435 }
436 _ => {
437 return plan_err!(
438 "Cannot automatically convert {left_current_type} to {left_target_type}"
439 );
440 }
441 };
442
443 Ok(e)
444 }
445}
446
447impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
448 type Node = Expr;
449
450 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
451 match expr {
452 Expr::Unnest(_) => not_impl_err!(
453 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
454 ),
455 Expr::ScalarSubquery(Subquery {
456 subquery,
457 outer_ref_columns,
458 spans,
459 }) => {
460 let new_plan =
461 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
462 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
463 subquery: Arc::new(new_plan),
464 outer_ref_columns,
465 spans,
466 })))
467 }
468 Expr::Exists(Exists { subquery, negated }) => {
469 let new_plan = analyze_internal(
470 self.schema,
471 Arc::unwrap_or_clone(subquery.subquery),
472 )?
473 .data;
474 Ok(Transformed::yes(Expr::Exists(Exists {
475 subquery: Subquery {
476 subquery: Arc::new(new_plan),
477 outer_ref_columns: subquery.outer_ref_columns,
478 spans: subquery.spans,
479 },
480 negated,
481 })))
482 }
483 Expr::InSubquery(InSubquery {
484 expr,
485 subquery,
486 negated,
487 }) => {
488 let new_plan = analyze_internal(
489 self.schema,
490 Arc::unwrap_or_clone(subquery.subquery),
491 )?
492 .data;
493 let expr_type = expr.get_type(self.schema)?;
494 let subquery_type = new_plan.schema().field(0).data_type();
495 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
496 plan_datafusion_err!(
497 "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
498 ),
499 )?;
500 let new_subquery = Subquery {
501 subquery: Arc::new(new_plan),
502 outer_ref_columns: subquery.outer_ref_columns,
503 spans: subquery.spans,
504 };
505 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
506 Box::new(expr.cast_to(&common_type, self.schema)?),
507 cast_subquery(new_subquery, &common_type)?,
508 negated,
509 ))))
510 }
511 Expr::SetComparison(SetComparison {
512 expr,
513 subquery,
514 op,
515 quantifier,
516 }) => {
517 let new_plan = analyze_internal(
518 self.schema,
519 Arc::unwrap_or_clone(subquery.subquery),
520 )?
521 .data;
522 let expr_type = expr.get_type(self.schema)?;
523 let subquery_type = new_plan.schema().field(0).data_type();
524 if (expr_type.is_numeric() && subquery_type.is_string())
525 || (subquery_type.is_numeric() && expr_type.is_string())
526 {
527 return plan_err!(
528 "expr type {expr_type} can't cast to {subquery_type} in SetComparison"
529 );
530 }
531 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
532 plan_datafusion_err!(
533 "expr type {expr_type} can't cast to {subquery_type} in SetComparison"
534 ),
535 )?;
536 let new_subquery = Subquery {
537 subquery: Arc::new(new_plan),
538 outer_ref_columns: subquery.outer_ref_columns,
539 spans: subquery.spans,
540 };
541 Ok(Transformed::yes(Expr::SetComparison(SetComparison::new(
542 Box::new(expr.cast_to(&common_type, self.schema)?),
543 cast_subquery(new_subquery, &common_type)?,
544 op,
545 quantifier,
546 ))))
547 }
548 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
549 *expr,
550 self.schema,
551 )?))),
552 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
553 get_casted_expr_for_bool_op(*expr, self.schema)?,
554 ))),
555 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
556 get_casted_expr_for_bool_op(*expr, self.schema)?,
557 ))),
558 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
559 get_casted_expr_for_bool_op(*expr, self.schema)?,
560 ))),
561 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
562 get_casted_expr_for_bool_op(*expr, self.schema)?,
563 ))),
564 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
565 get_casted_expr_for_bool_op(*expr, self.schema)?,
566 ))),
567 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
568 get_casted_expr_for_bool_op(*expr, self.schema)?,
569 ))),
570 Expr::Negative(expr) => {
571 let data_type = expr.get_type(self.schema)?;
572 if data_type.is_null()
573 || is_signed_numeric(&data_type)
574 || is_interval(&data_type)
575 || is_timestamp(&data_type)
576 {
577 Ok(Transformed::no(Expr::Negative(expr)))
578 } else {
579 plan_err!(
580 "Negation only supports numeric, interval and timestamp types"
581 )
582 }
583 }
584 Expr::Like(Like {
585 negated,
586 expr,
587 pattern,
588 escape_char,
589 case_insensitive,
590 }) => {
591 let left_type = expr.get_type(self.schema)?;
592 let right_type = pattern.get_type(self.schema)?;
593 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
594 let op_name = if case_insensitive {
595 "ILIKE"
596 } else {
597 "LIKE"
598 };
599 plan_datafusion_err!(
600 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
601 )
602 })?;
603 let expr = match left_type {
604 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
605 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
606 };
607 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
608 Ok(Transformed::yes(Expr::Like(Like::new(
609 negated,
610 expr,
611 pattern,
612 escape_char,
613 case_insensitive,
614 ))))
615 }
616 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
617 let (left, right) =
618 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
619 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
620 Box::new(left),
621 op,
622 Box::new(right),
623 ))))
624 }
625 Expr::Between(Between {
626 expr,
627 negated,
628 low,
629 high,
630 }) => {
631 let expr_type = expr.get_type(self.schema)?;
632 let low_type = low.get_type(self.schema)?;
633 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
634 .ok_or_else(|| {
635 internal_datafusion_err!(
636 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
637 )
638 })?;
639 let high_type = high.get_type(self.schema)?;
640 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
641 .ok_or_else(|| {
642 internal_datafusion_err!(
643 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
644 )
645 })?;
646 let coercion_type =
647 comparison_coercion(&low_coerced_type, &high_coerced_type)
648 .ok_or_else(|| {
649 internal_datafusion_err!(
650 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
651 )
652 })?;
653 Ok(Transformed::yes(Expr::Between(Between::new(
654 Box::new(expr.cast_to(&coercion_type, self.schema)?),
655 negated,
656 Box::new(low.cast_to(&coercion_type, self.schema)?),
657 Box::new(high.cast_to(&coercion_type, self.schema)?),
658 ))))
659 }
660 Expr::InList(InList {
661 expr,
662 list,
663 negated,
664 }) => {
665 let expr_data_type = expr.get_type(self.schema)?;
666 let list_data_types = list
667 .iter()
668 .map(|list_expr| list_expr.get_type(self.schema))
669 .collect::<Result<Vec<_>>>()?;
670 let result_type =
671 get_coerce_type_for_list(&expr_data_type, &list_data_types);
672 match result_type {
673 None => plan_err!(
674 "Can not find compatible types to compare {expr_data_type} with [{}]",
675 list_data_types.iter().join(", ")
676 ),
677 Some(coerced_type) => {
678 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
680 let cast_list_expr = list
681 .into_iter()
682 .map(|list_expr| {
683 list_expr.cast_to(&coerced_type, self.schema)
684 })
685 .collect::<Result<Vec<_>>>()?;
686 Ok(Transformed::yes(Expr::InList(InList::new(
687 Box::new(cast_expr),
688 cast_list_expr,
689 negated,
690 ))))
691 }
692 }
693 }
694 Expr::Case(case) => {
695 let case = coerce_case_expression(case, self.schema)?;
696 Ok(Transformed::yes(Expr::Case(case)))
697 }
698 Expr::ScalarFunction(ScalarFunction { func, args }) => {
699 let new_expr =
700 coerce_arguments_for_signature(args, self.schema, func.as_ref())?;
701 Ok(Transformed::yes(Expr::ScalarFunction(
702 ScalarFunction::new_udf(func, new_expr),
703 )))
704 }
705 Expr::AggregateFunction(expr::AggregateFunction {
706 func,
707 params:
708 AggregateFunctionParams {
709 args,
710 distinct,
711 filter,
712 order_by,
713 null_treatment,
714 },
715 }) => {
716 let new_expr =
717 coerce_arguments_for_signature(args, self.schema, func.as_ref())?;
718 Ok(Transformed::yes(Expr::AggregateFunction(
719 expr::AggregateFunction::new_udf(
720 func,
721 new_expr,
722 distinct,
723 filter,
724 order_by,
725 null_treatment,
726 ),
727 )))
728 }
729 Expr::WindowFunction(window_fun) => {
730 let WindowFunction {
731 fun,
732 params:
733 expr::WindowFunctionParams {
734 args,
735 partition_by,
736 order_by,
737 window_frame,
738 filter,
739 null_treatment,
740 distinct,
741 },
742 } = *window_fun;
743 let window_frame =
744 coerce_window_frame(window_frame, self.schema, &order_by)?;
745
746 let args = match &fun {
747 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
748 coerce_arguments_for_signature(args, self.schema, udf.as_ref())?
749 }
750 expr::WindowFunctionDefinition::WindowUDF(udf) => {
751 coerce_arguments_for_signature(args, self.schema, udf.as_ref())?
752 }
753 };
754
755 let new_expr = Expr::from(WindowFunction {
756 fun,
757 params: expr::WindowFunctionParams {
758 args,
759 partition_by,
760 order_by,
761 window_frame,
762 filter,
763 null_treatment,
764 distinct,
765 },
766 });
767 Ok(Transformed::yes(new_expr))
768 }
769 Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
770 let current_fields = args
771 .iter()
772 .map(|arg| match arg {
773 Expr::Lambda(lambda) => Ok(ValueOrLambda::Lambda(
774 lambda.body.to_field(self.schema)?.1,
775 )),
776 _ => Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)),
777 })
778 .collect::<Result<Vec<_>>>()?;
779
780 let new_fields = value_fields_with_higher_order_udf_and_lambdas(
781 ¤t_fields,
782 func.as_ref(),
783 )?;
784
785 let new_args = std::iter::zip(args, new_fields)
786 .map(|(arg, new_field)| match (&arg, new_field) {
787 (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => Ok(arg),
788 (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => internal_err!("value_fields_with_higher_order_udf returned a value for a lambda argument"),
789 (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema),
790 (_, ValueOrLambda::Lambda(_)) => internal_err!("value_fields_with_higher_order_udf returned a lambda for a value argument"),
791 })
792 .collect::<Result<_>>()?;
793
794 Ok(Transformed::yes(Expr::HigherOrderFunction(
795 HigherOrderFunction::new(func, new_args),
796 )))
797 }
798 #[expect(deprecated)]
800 Expr::Alias(_)
801 | Expr::Column(_)
802 | Expr::ScalarVariable(_, _)
803 | Expr::Literal(_, _)
804 | Expr::SimilarTo(_)
805 | Expr::IsNotNull(_)
806 | Expr::IsNull(_)
807 | Expr::Cast(_)
808 | Expr::TryCast(_)
809 | Expr::Wildcard { .. }
810 | Expr::GroupingSet(_)
811 | Expr::Placeholder(_)
812 | Expr::OuterReferenceColumn(_, _)
813 | Expr::Lambda(_)
814 | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)),
815 }
816 }
817}
818
819fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
821 let metadata = dfschema.as_arrow().metadata.clone();
822 let mut transformed = false;
823
824 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
825 dfschema
826 .iter()
827 .map(|(qualifier, field)| match field.data_type() {
828 DataType::Utf8View => {
829 transformed = true;
830 (
831 qualifier.cloned() as Option<TableReference>,
832 Arc::new(Field::new(
833 field.name(),
834 DataType::LargeUtf8,
835 field.is_nullable(),
836 )),
837 )
838 }
839 DataType::BinaryView => {
840 transformed = true;
841 (
842 qualifier.cloned() as Option<TableReference>,
843 Arc::new(Field::new(
844 field.name(),
845 DataType::LargeBinary,
846 field.is_nullable(),
847 )),
848 )
849 }
850 _ => (
851 qualifier.cloned() as Option<TableReference>,
852 Arc::clone(field),
853 ),
854 })
855 .unzip();
856
857 if !transformed {
858 return None;
859 }
860
861 let schema = Schema::new_with_metadata(transformed_fields, metadata);
862 Some(DFSchema::from_field_specific_qualified_schema(
863 qualifiers,
864 &Arc::new(schema),
865 ))
866}
867
868fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
871 match value {
872 ScalarValue::Utf8(Some(val)) => {
874 ScalarValue::try_from_string(val.clone(), target_type)
875 }
876 s => {
877 if s.is_null() {
878 ScalarValue::try_from(target_type)
880 } else {
881 Ok(s.clone())
885 }
886 }
887 }
888}
889
890fn coerce_scalar_range_aware(
897 target_type: &DataType,
898 value: &ScalarValue,
899) -> Result<ScalarValue> {
900 coerce_scalar(target_type, value).or_else(|err| {
901 if let Some(largest_type) = get_widest_type_in_family(target_type) {
903 coerce_scalar(largest_type, value).map_or_else(
904 |_| exec_err!("Cannot cast {value} to {target_type}"),
905 |_| ScalarValue::try_from(target_type),
906 )
907 } else {
908 Err(err)
909 }
910 })
911}
912
913fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
917 match given_type {
918 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
919 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
920 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
921 _ => None,
922 }
923}
924
925fn coerce_frame_bound(
927 target_type: &DataType,
928 bound: WindowFrameBound,
929) -> Result<WindowFrameBound> {
930 match bound {
931 WindowFrameBound::Preceding(v) => {
932 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
933 }
934 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
935 WindowFrameBound::Following(v) => {
936 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
937 }
938 }
939}
940
941fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
942 if col_type.is_numeric()
943 || col_type.is_string()
944 || col_type.is_null()
945 || matches!(
946 col_type,
947 DataType::List(_)
948 | DataType::LargeList(_)
949 | DataType::FixedSizeList(_, _)
950 | DataType::Boolean
951 )
952 {
953 Ok(col_type.clone())
954 } else if is_datetime(col_type) {
955 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
956 } else if let DataType::Dictionary(_, value_type) = col_type {
957 extract_window_frame_target_type(value_type)
958 } else {
959 internal_err!("Cannot run range queries on datatype: {col_type}")
960 }
961}
962
963fn coerce_window_frame(
966 window_frame: WindowFrame,
967 schema: &DFSchema,
968 expressions: &[Sort],
969) -> Result<WindowFrame> {
970 let mut window_frame = window_frame;
971 let target_type = match window_frame.units {
972 WindowFrameUnits::Range => {
973 let current_types = expressions
974 .first()
975 .map(|s| s.expr.get_type(schema))
976 .transpose()?;
977 if let Some(col_type) = current_types {
978 extract_window_frame_target_type(&col_type)?
979 } else {
980 return internal_err!("ORDER BY column cannot be empty");
981 }
982 }
983 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
984 };
985 window_frame.start_bound =
986 coerce_frame_bound(&target_type, window_frame.start_bound)?;
987 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
988 Ok(window_frame)
989}
990
991fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
994 let left_type = expr.get_type(schema)?;
995 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
996 .get_input_types()?;
997 expr.cast_to(&DataType::Boolean, schema)
998}
999
1000fn coerce_arguments_for_signature<F: UDFCoercionExt>(
1005 expressions: Vec<Expr>,
1006 schema: &DFSchema,
1007 func: &F,
1008) -> Result<Vec<Expr>> {
1009 let current_fields = expressions
1010 .iter()
1011 .map(|e| e.to_field(schema).map(|(_, f)| f))
1012 .collect::<Result<Vec<_>>>()?;
1013
1014 let coerced_types = fields_with_udf(¤t_fields, func)?
1015 .into_iter()
1016 .map(|f| f.data_type().clone())
1017 .collect::<Vec<_>>();
1018
1019 expressions
1020 .into_iter()
1021 .enumerate()
1022 .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
1023 .collect()
1024}
1025
1026fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
1027 let case_type = case
1059 .expr
1060 .as_ref()
1061 .map(|expr| expr.get_type(schema))
1062 .transpose()?;
1063 let then_types = case
1064 .when_then_expr
1065 .iter()
1066 .map(|(_when, then)| then.get_type(schema))
1067 .collect::<Result<Vec<_>>>()?;
1068 let else_type = case
1069 .else_expr
1070 .as_ref()
1071 .map(|expr| expr.get_type(schema))
1072 .transpose()?;
1073
1074 let case_when_coerce_type = case_type
1076 .as_ref()
1077 .map(|case_type| {
1078 let when_types = case
1079 .when_then_expr
1080 .iter()
1081 .map(|(when, _then)| when.get_type(schema))
1082 .collect::<Result<Vec<_>>>()?;
1083 let coerced_type = get_coerce_type_for_case_when(&when_types, case_type);
1084 coerced_type.ok_or_else(|| {
1085 plan_datafusion_err!(
1086 "Failed to coerce case ({case_type}) and when ({}) \
1087 to common types in CASE WHEN expression",
1088 when_types.iter().join(", ")
1089 )
1090 })
1091 })
1092 .transpose()?;
1093 let then_else_coerce_type =
1094 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
1095 || {
1096 if let Some(else_type) = else_type {
1097 plan_datafusion_err!(
1098 "Failed to coerce then ({}) and else ({else_type}) \
1099 to common types in CASE WHEN expression",
1100 then_types.iter().join(", ")
1101 )
1102 } else {
1103 plan_datafusion_err!(
1104 "Failed to coerce then ({}) and else (None) \
1105 to common types in CASE WHEN expression",
1106 then_types.iter().join(", ")
1107 )
1108 }
1109 },
1110 )?;
1111
1112 let case_expr = case
1114 .expr
1115 .zip(case_when_coerce_type.as_ref())
1116 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
1117 .transpose()?
1118 .map(Box::new);
1119 let when_then = case
1120 .when_then_expr
1121 .into_iter()
1122 .map(|(when, then)| {
1123 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
1124 let when = when.cast_to(when_type, schema).map_err(|e| {
1125 DataFusionError::Context(
1126 format!(
1127 "WHEN expressions in CASE couldn't be \
1128 converted to common type ({when_type})"
1129 ),
1130 Box::new(e),
1131 )
1132 })?;
1133 let then = then.cast_to(&then_else_coerce_type, schema)?;
1134 Ok((Box::new(when), Box::new(then)))
1135 })
1136 .collect::<Result<Vec<_>>>()?;
1137 let else_expr = case
1138 .else_expr
1139 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
1140 .transpose()?
1141 .map(Box::new);
1142
1143 Ok(Case::new(case_expr, when_then, else_expr))
1144}
1145
1146pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1188 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1189}
1190fn coerce_union_schema_with_schema(
1191 inputs: &[Arc<LogicalPlan>],
1192 base_schema: &DFSchemaRef,
1193) -> Result<DFSchema> {
1194 let mut union_datatypes = base_schema
1195 .fields()
1196 .iter()
1197 .map(|f| f.data_type().clone())
1198 .collect::<Vec<_>>();
1199 let mut union_nullabilities = base_schema
1200 .fields()
1201 .iter()
1202 .map(|f| f.is_nullable())
1203 .collect::<Vec<_>>();
1204 let mut union_field_meta = base_schema
1205 .fields()
1206 .iter()
1207 .map(|f| f.metadata().clone())
1208 .collect::<Vec<_>>();
1209
1210 let mut metadata = base_schema.metadata().clone();
1211
1212 for (i, plan) in inputs.iter().enumerate() {
1213 let plan_schema = plan.schema();
1214 metadata.extend(plan_schema.metadata().clone());
1215
1216 if plan_schema.fields().len() != base_schema.fields().len() {
1217 return plan_err!(
1218 "Union schemas have different number of fields: \
1219 query 1 has {} fields whereas query {} has {} fields",
1220 base_schema.fields().len(),
1221 i + 1,
1222 plan_schema.fields().len()
1223 );
1224 }
1225
1226 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1228 union_datatypes.iter_mut(),
1229 union_nullabilities.iter_mut(),
1230 union_field_meta.iter_mut(),
1231 plan_schema.fields().iter()
1232 ) {
1233 let coerced_type =
1234 type_union_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1235 || {
1236 plan_datafusion_err!(
1237 "Incompatible inputs for Union: Previous inputs were \
1238 of type {}, but got incompatible type {} on column '{}'",
1239 union_datatype,
1240 plan_field.data_type(),
1241 plan_field.name()
1242 )
1243 },
1244 )?;
1245
1246 *union_datatype = coerced_type;
1247 *union_nullable = *union_nullable || plan_field.is_nullable();
1248 union_field_map.extend(plan_field.metadata().clone());
1249 }
1250 }
1251 let union_qualified_fields = izip!(
1252 base_schema.fields(),
1253 union_datatypes.into_iter(),
1254 union_nullabilities,
1255 union_field_meta.into_iter()
1256 )
1257 .map(|(field, datatype, nullable, metadata)| {
1258 let mut field = Field::new(field.name().clone(), datatype, nullable);
1259 field.set_metadata(metadata);
1260 (None, field.into())
1261 })
1262 .collect::<Vec<_>>();
1263
1264 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1265}
1266
1267fn project_with_column_index(
1269 expr: Vec<Expr>,
1270 input: Arc<LogicalPlan>,
1271 schema: DFSchemaRef,
1272) -> Result<LogicalPlan> {
1273 let alias_expr = expr
1274 .into_iter()
1275 .enumerate()
1276 .map(|(i, e)| match e {
1277 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1278 Ok(e.unalias().alias(schema.field(i).name()))
1279 }
1280 Expr::Column(Column {
1281 relation: _,
1282 ref name,
1283 spans: _,
1284 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1285 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1286 #[expect(deprecated)]
1287 Expr::Wildcard { .. } => {
1288 plan_err!("Wildcard should be expanded before type coercion")
1289 }
1290 _ => Ok(e.alias(schema.field(i).name())),
1291 })
1292 .collect::<Result<Vec<_>>>()?;
1293
1294 Projection::try_new_with_schema(alias_expr, input, schema)
1295 .map(LogicalPlan::Projection)
1296}
1297
1298#[cfg(test)]
1299mod test {
1300
1301 use std::sync::Arc;
1302
1303 use arrow::datatypes::DataType::Utf8;
1304 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1305 use insta::assert_snapshot;
1306
1307 use crate::analyzer::Analyzer;
1308 use crate::analyzer::type_coercion::{
1309 TypeCoercion, TypeCoercionRewriter, coerce_case_expression,
1310 };
1311 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1312 use datafusion_common::config::ConfigOptions;
1313 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1314 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1315 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1316 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1317 use datafusion_expr::test::function_stub::avg_udaf;
1318 use datafusion_expr::{
1319 AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr,
1320 ExprSchemable, Filter, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF,
1321 ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Union, Volatility, cast,
1322 col, create_udaf, is_true, lit,
1323 };
1324 use datafusion_functions_aggregate::average::AvgAccumulator;
1325 use datafusion_sql::TableReference;
1326
1327 fn empty() -> Arc<LogicalPlan> {
1328 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1329 produce_one_row: false,
1330 schema: Arc::new(DFSchema::empty()),
1331 }))
1332 }
1333
1334 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1335 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1336 produce_one_row: false,
1337 schema: Arc::new(
1338 DFSchema::from_unqualified_fields(
1339 vec![Field::new("a", data_type, true)].into(),
1340 std::collections::HashMap::new(),
1341 )
1342 .unwrap(),
1343 ),
1344 }))
1345 }
1346
1347 macro_rules! assert_analyzed_plan_eq {
1348 (
1349 $plan: expr,
1350 @ $expected: literal $(,)?
1351 ) => {{
1352 let options = ConfigOptions::default();
1353 let rule = Arc::new(TypeCoercion::new());
1354 assert_analyzed_plan_with_config_eq_snapshot!(
1355 options,
1356 rule,
1357 $plan,
1358 @ $expected,
1359 )
1360 }};
1361 }
1362
1363 macro_rules! coerce_on_output_if_viewtype {
1364 (
1365 $is_viewtype: expr,
1366 $plan: expr,
1367 @ $expected: literal $(,)?
1368 ) => {{
1369 let mut options = ConfigOptions::default();
1370 if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1372 let rule = Arc::new(TypeCoercion::new());
1373
1374 assert_analyzed_plan_with_config_eq_snapshot!(
1375 options,
1376 rule,
1377 $plan,
1378 @ $expected,
1379 )
1380 }};
1381 }
1382
1383 fn assert_type_coercion_error(
1384 plan: LogicalPlan,
1385 expected_substr: &str,
1386 ) -> Result<()> {
1387 let options = ConfigOptions::default();
1388 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1389
1390 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1391 Ok(succeeded_plan) => {
1392 panic!(
1393 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1394 );
1395 }
1396 Err(e) => {
1397 let msg = e.to_string();
1398 assert!(
1399 msg.contains(expected_substr),
1400 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1401 );
1402 }
1403 }
1404
1405 Ok(())
1406 }
1407
1408 #[test]
1409 fn simple_case() -> Result<()> {
1410 let expr = col("a").lt(lit(2_u32));
1411 let empty = empty_with_type(DataType::Float64);
1412 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1413
1414 assert_analyzed_plan_eq!(
1415 plan,
1416 @r"
1417 Projection: a < CAST(UInt32(2) AS Float64)
1418 EmptyRelation: rows=0
1419 "
1420 )
1421 }
1422
1423 #[test]
1424 fn negative_expr_wrapped_by_is_null_errors() -> Result<()> {
1425 let predicate = Expr::IsNull(Box::new(Expr::Negative(Box::new(lit("a")))));
1426 let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty())?);
1427
1428 assert_type_coercion_error(
1429 plan,
1430 "Negation only supports numeric, interval and timestamp types",
1431 )
1432 }
1433
1434 #[test]
1435 fn test_coerce_union() -> Result<()> {
1436 let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1437 produce_one_row: false,
1438 schema: Arc::new(
1439 DFSchema::try_from_qualified_schema(
1440 TableReference::full("datafusion", "test", "foo"),
1441 &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1442 )
1443 .unwrap(),
1444 ),
1445 }));
1446 let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1447 produce_one_row: false,
1448 schema: Arc::new(
1449 DFSchema::try_from_qualified_schema(
1450 TableReference::full("datafusion", "test", "foo"),
1451 &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1452 )
1453 .unwrap(),
1454 ),
1455 }));
1456 let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1457 left_plan, right_plan,
1458 ])?);
1459 let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1460 .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1461 let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1462 vec![col("a")],
1463 Arc::new(analyzed_union),
1464 )?);
1465
1466 assert_analyzed_plan_eq!(
1467 top_level_plan,
1468 @r"
1469 Projection: a
1470 Union
1471 Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1472 EmptyRelation: rows=0
1473 EmptyRelation: rows=0
1474 "
1475 )
1476 }
1477
1478 #[test]
1479 fn coerce_utf8view_output() -> Result<()> {
1480 let expr = col("a");
1483 let empty = empty_with_type(DataType::Utf8View);
1484 let plan = LogicalPlan::Projection(Projection::try_new(
1485 vec![expr.clone()],
1486 Arc::clone(&empty),
1487 )?);
1488
1489 coerce_on_output_if_viewtype!(
1491 false,
1492 plan.clone(),
1493 @r"
1494 Projection: a
1495 EmptyRelation: rows=0
1496 "
1497 )?;
1498
1499 coerce_on_output_if_viewtype!(
1501 true,
1502 plan.clone(),
1503 @r"
1504 Projection: CAST(a AS LargeUtf8) AS a
1505 EmptyRelation: rows=0
1506 "
1507 )?;
1508
1509 let bool_expr = col("a").lt(lit("foo"));
1512 let bool_plan = LogicalPlan::Projection(Projection::try_new(
1513 vec![bool_expr],
1514 Arc::clone(&empty),
1515 )?);
1516 coerce_on_output_if_viewtype!(
1518 false,
1519 bool_plan.clone(),
1520 @r#"
1521 Projection: a < CAST(Utf8("foo") AS Utf8View)
1522 EmptyRelation: rows=0
1523 "#
1524 )?;
1525
1526 coerce_on_output_if_viewtype!(
1527 false,
1528 plan.clone(),
1529 @r"
1530 Projection: a
1531 EmptyRelation: rows=0
1532 "
1533 )?;
1534
1535 coerce_on_output_if_viewtype!(
1537 true,
1538 plan.clone(),
1539 @r"
1540 Projection: CAST(a AS LargeUtf8) AS a
1541 EmptyRelation: rows=0
1542 "
1543 )?;
1544
1545 let sort_expr = expr.sort(true, true);
1548 let sort_plan = LogicalPlan::Sort(Sort {
1549 expr: vec![sort_expr],
1550 input: Arc::new(plan),
1551 fetch: None,
1552 });
1553
1554 coerce_on_output_if_viewtype!(
1556 false,
1557 sort_plan.clone(),
1558 @r"
1559 Sort: a ASC NULLS FIRST
1560 Projection: a
1561 EmptyRelation: rows=0
1562 "
1563 )?;
1564
1565 coerce_on_output_if_viewtype!(
1567 true,
1568 sort_plan.clone(),
1569 @r"
1570 Projection: CAST(a AS LargeUtf8) AS a
1571 Sort: a ASC NULLS FIRST
1572 Projection: a
1573 EmptyRelation: rows=0
1574 "
1575 )?;
1576
1577 let plan = LogicalPlan::Projection(Projection::try_new(
1580 vec![col("a")],
1581 Arc::new(sort_plan),
1582 )?);
1583 coerce_on_output_if_viewtype!(
1585 false,
1586 plan.clone(),
1587 @r"
1588 Projection: a
1589 Sort: a ASC NULLS FIRST
1590 Projection: a
1591 EmptyRelation: rows=0
1592 "
1593 )?;
1594 coerce_on_output_if_viewtype!(
1596 true,
1597 plan.clone(),
1598 @r"
1599 Projection: CAST(a AS LargeUtf8) AS a
1600 Sort: a ASC NULLS FIRST
1601 Projection: a
1602 EmptyRelation: rows=0
1603 "
1604 )?;
1605
1606 Ok(())
1607 }
1608
1609 #[test]
1610 fn coerce_binaryview_output() -> Result<()> {
1611 let expr = col("a");
1614 let empty = empty_with_type(DataType::BinaryView);
1615 let plan = LogicalPlan::Projection(Projection::try_new(
1616 vec![expr.clone()],
1617 Arc::clone(&empty),
1618 )?);
1619
1620 coerce_on_output_if_viewtype!(
1622 false,
1623 plan.clone(),
1624 @r"
1625 Projection: a
1626 EmptyRelation: rows=0
1627 "
1628 )?;
1629
1630 coerce_on_output_if_viewtype!(
1632 true,
1633 plan.clone(),
1634 @r"
1635 Projection: CAST(a AS LargeBinary) AS a
1636 EmptyRelation: rows=0
1637 "
1638 )?;
1639
1640 let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1643 let bool_plan = LogicalPlan::Projection(Projection::try_new(
1644 vec![bool_expr],
1645 Arc::clone(&empty),
1646 )?);
1647
1648 coerce_on_output_if_viewtype!(
1650 false,
1651 bool_plan.clone(),
1652 @r#"
1653 Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1654 EmptyRelation: rows=0
1655 "#
1656 )?;
1657
1658 coerce_on_output_if_viewtype!(
1660 true,
1661 bool_plan.clone(),
1662 @r#"
1663 Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1664 EmptyRelation: rows=0
1665 "#
1666 )?;
1667
1668 let sort_expr = expr.sort(true, true);
1671 let sort_plan = LogicalPlan::Sort(Sort {
1672 expr: vec![sort_expr],
1673 input: Arc::new(plan),
1674 fetch: None,
1675 });
1676
1677 coerce_on_output_if_viewtype!(
1679 false,
1680 sort_plan.clone(),
1681 @r"
1682 Sort: a ASC NULLS FIRST
1683 Projection: a
1684 EmptyRelation: rows=0
1685 "
1686 )?;
1687 coerce_on_output_if_viewtype!(
1689 true,
1690 sort_plan.clone(),
1691 @r"
1692 Projection: CAST(a AS LargeBinary) AS a
1693 Sort: a ASC NULLS FIRST
1694 Projection: a
1695 EmptyRelation: rows=0
1696 "
1697 )?;
1698
1699 let plan = LogicalPlan::Projection(Projection::try_new(
1702 vec![col("a")],
1703 Arc::new(sort_plan),
1704 )?);
1705
1706 coerce_on_output_if_viewtype!(
1708 false,
1709 plan.clone(),
1710 @r"
1711 Projection: a
1712 Sort: a ASC NULLS FIRST
1713 Projection: a
1714 EmptyRelation: rows=0
1715 "
1716 )?;
1717
1718 coerce_on_output_if_viewtype!(
1720 true,
1721 plan.clone(),
1722 @r"
1723 Projection: CAST(a AS LargeBinary) AS a
1724 Sort: a ASC NULLS FIRST
1725 Projection: a
1726 EmptyRelation: rows=0
1727 "
1728 )?;
1729
1730 Ok(())
1731 }
1732
1733 #[test]
1734 fn nested_case() -> Result<()> {
1735 let expr = col("a").lt(lit(2_u32));
1736 let empty = empty_with_type(DataType::Float64);
1737
1738 let plan = LogicalPlan::Projection(Projection::try_new(
1739 vec![expr.clone().or(expr)],
1740 empty,
1741 )?);
1742
1743 assert_analyzed_plan_eq!(
1744 plan,
1745 @r"
1746 Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1747 EmptyRelation: rows=0
1748 "
1749 )
1750 }
1751
1752 #[derive(Debug, PartialEq, Eq, Hash)]
1753 struct TestScalarUDF {
1754 signature: Signature,
1755 }
1756
1757 impl ScalarUDFImpl for TestScalarUDF {
1758 fn name(&self) -> &str {
1759 "TestScalarUDF"
1760 }
1761
1762 fn signature(&self) -> &Signature {
1763 &self.signature
1764 }
1765
1766 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1767 Ok(Utf8)
1768 }
1769
1770 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1771 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1772 }
1773 }
1774
1775 #[test]
1776 fn scalar_udf() -> Result<()> {
1777 let empty = empty();
1778
1779 let udf = ScalarUDF::from(TestScalarUDF {
1780 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1781 })
1782 .call(vec![lit(123_i32)]);
1783 let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1784
1785 assert_analyzed_plan_eq!(
1786 plan,
1787 @r"
1788 Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1789 EmptyRelation: rows=0
1790 "
1791 )
1792 }
1793
1794 #[test]
1795 fn scalar_udf_invalid_input() -> Result<()> {
1796 let empty = empty();
1797 let udf = ScalarUDF::from(TestScalarUDF {
1798 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1799 })
1800 .call(vec![lit("Apple")]);
1801 Projection::try_new(vec![udf], empty)
1802 .expect_err("Expected an error due to incorrect function input");
1803
1804 Ok(())
1805 }
1806
1807 #[test]
1808 fn scalar_function() -> Result<()> {
1809 let empty = empty();
1811 let lit_expr = lit(10i64);
1812 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1813 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1814 });
1815 let scalar_function_expr =
1816 Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1817 let plan = LogicalPlan::Projection(Projection::try_new(
1818 vec![scalar_function_expr],
1819 empty,
1820 )?);
1821
1822 assert_analyzed_plan_eq!(
1823 plan,
1824 @r"
1825 Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1826 EmptyRelation: rows=0
1827 "
1828 )
1829 }
1830
1831 #[test]
1832 fn agg_udaf() -> Result<()> {
1833 let empty = empty();
1834 let my_avg = create_udaf(
1835 "MY_AVG",
1836 vec![DataType::Float64],
1837 Arc::new(DataType::Float64),
1838 Volatility::Immutable,
1839 Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1840 Arc::new(vec![DataType::UInt64, DataType::Float64]),
1841 );
1842 let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1843 Arc::new(my_avg),
1844 vec![lit(10i64)],
1845 false,
1846 None,
1847 vec![],
1848 None,
1849 ));
1850 let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1851
1852 assert_analyzed_plan_eq!(
1853 plan,
1854 @r"
1855 Projection: MY_AVG(CAST(Int64(10) AS Float64))
1856 EmptyRelation: rows=0
1857 "
1858 )
1859 }
1860
1861 #[test]
1862 fn agg_udaf_invalid_input() -> Result<()> {
1863 let empty = empty();
1864 let return_type = DataType::Float64;
1865 let accumulator: AccumulatorFactoryFunction =
1866 Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1867 let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1868 "MY_AVG",
1869 Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1870 return_type,
1871 accumulator,
1872 vec![
1873 Field::new("count", DataType::UInt64, true).into(),
1874 Field::new("avg", DataType::Float64, true).into(),
1875 ],
1876 ));
1877 let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1878 Arc::new(my_avg),
1879 vec![lit("10")],
1880 false,
1881 None,
1882 vec![],
1883 None,
1884 ));
1885
1886 let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1887 assert!(
1888 err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1889 );
1890 Ok(())
1891 }
1892
1893 #[test]
1894 fn agg_function_case() -> Result<()> {
1895 let empty = empty();
1896 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1897 avg_udaf(),
1898 vec![lit(12f64)],
1899 false,
1900 None,
1901 vec![],
1902 None,
1903 ));
1904 let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1905
1906 assert_analyzed_plan_eq!(
1907 plan,
1908 @r"
1909 Projection: avg(Float64(12))
1910 EmptyRelation: rows=0
1911 "
1912 )?;
1913
1914 let empty = empty_with_type(DataType::Int32);
1915 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1916 avg_udaf(),
1917 vec![cast(col("a"), DataType::Float64)],
1918 false,
1919 None,
1920 vec![],
1921 None,
1922 ));
1923 let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1924
1925 assert_analyzed_plan_eq!(
1926 plan,
1927 @r"
1928 Projection: avg(CAST(a AS Float64))
1929 EmptyRelation: rows=0
1930 "
1931 )
1932 }
1933
1934 #[test]
1935 fn agg_function_invalid_input_avg() -> Result<()> {
1936 let empty = empty();
1937 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1938 avg_udaf(),
1939 vec![lit("1")],
1940 false,
1941 None,
1942 vec![],
1943 None,
1944 ));
1945 let err = Projection::try_new(vec![agg_expr], empty)
1946 .err()
1947 .unwrap()
1948 .strip_backtrace();
1949 assert!(
1950 err.contains("Function 'avg' failed to match any signature"),
1951 "Err: {err:?}"
1952 );
1953 Ok(())
1954 }
1955
1956 #[test]
1957 fn binary_op_date32_op_interval() -> Result<()> {
1958 let expr = cast(lit("1998-03-18"), DataType::Date32)
1960 + lit(ScalarValue::new_interval_dt(123, 456));
1961 let empty = empty();
1962 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1963
1964 assert_analyzed_plan_eq!(
1965 plan,
1966 @r#"
1967 Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1968 EmptyRelation: rows=0
1969 "#
1970 )
1971 }
1972
1973 #[test]
1974 fn inlist_case() -> Result<()> {
1975 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1977 let empty = empty_with_type(DataType::Int64);
1978 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1979 assert_analyzed_plan_eq!(
1980 plan,
1981 @r"
1982 Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1983 EmptyRelation: rows=0
1984 ")?;
1985
1986 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1988 let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1989 produce_one_row: false,
1990 schema: Arc::new(DFSchema::from_unqualified_fields(
1991 vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1992 std::collections::HashMap::new(),
1993 )?),
1994 }));
1995 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1996 assert_analyzed_plan_eq!(
1997 plan,
1998 @r"
1999 Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
2000 EmptyRelation: rows=0
2001 ")
2002 }
2003
2004 #[test]
2005 fn between_case() -> Result<()> {
2006 let expr = col("a").between(
2007 lit("2002-05-08"),
2008 cast(lit("2002-05-08"), DataType::Date32)
2010 + lit(ScalarValue::new_interval_ym(0, 1)),
2011 );
2012 let empty = empty_with_type(Utf8);
2013 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
2014
2015 assert_analyzed_plan_eq!(
2016 plan,
2017 @r#"
2018 Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
2019 EmptyRelation: rows=0
2020 "#
2021 )
2022 }
2023
2024 #[test]
2025 fn between_infer_cheap_type() -> Result<()> {
2026 let expr = col("a").between(
2027 cast(lit("2002-05-08"), DataType::Date32)
2029 + lit(ScalarValue::new_interval_ym(0, 1)),
2030 lit("2002-12-08"),
2031 );
2032 let empty = empty_with_type(Utf8);
2033 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
2034
2035 assert_analyzed_plan_eq!(
2037 plan,
2038 @r#"
2039 Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
2040 EmptyRelation: rows=0
2041 "#
2042 )
2043 }
2044
2045 #[test]
2046 fn between_null() -> Result<()> {
2047 let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
2048 let empty = empty();
2049 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
2050
2051 assert_analyzed_plan_eq!(
2052 plan,
2053 @r"
2054 Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
2055 EmptyRelation: rows=0
2056 "
2057 )
2058 }
2059
2060 #[test]
2061 fn is_bool_for_type_coercion() -> Result<()> {
2062 let expr = col("a").is_true();
2064 let empty = empty_with_type(DataType::Boolean);
2065 let plan =
2066 LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2067
2068 assert_analyzed_plan_eq!(
2069 plan,
2070 @r"
2071 Projection: a IS TRUE
2072 EmptyRelation: rows=0
2073 "
2074 )?;
2075
2076 let empty = empty_with_type(DataType::Int64);
2077 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2078 assert_type_coercion_error(
2079 plan,
2080 "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean",
2081 )?;
2082
2083 let expr = col("a").is_not_true();
2085 let empty = empty_with_type(DataType::Boolean);
2086 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2087
2088 assert_analyzed_plan_eq!(
2089 plan,
2090 @r"
2091 Projection: a IS NOT TRUE
2092 EmptyRelation: rows=0
2093 "
2094 )?;
2095
2096 let expr = col("a").is_false();
2098 let empty = empty_with_type(DataType::Boolean);
2099 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2100
2101 assert_analyzed_plan_eq!(
2102 plan,
2103 @r"
2104 Projection: a IS FALSE
2105 EmptyRelation: rows=0
2106 "
2107 )?;
2108
2109 let expr = col("a").is_not_false();
2111 let empty = empty_with_type(DataType::Boolean);
2112 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2113
2114 assert_analyzed_plan_eq!(
2115 plan,
2116 @r"
2117 Projection: a IS NOT FALSE
2118 EmptyRelation: rows=0
2119 "
2120 )
2121 }
2122
2123 #[test]
2124 fn like_for_type_coercion() -> Result<()> {
2125 let expr = Box::new(col("a"));
2127 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2128 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2129 let empty = empty_with_type(Utf8);
2130 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2131
2132 assert_analyzed_plan_eq!(
2133 plan,
2134 @r#"
2135 Projection: a LIKE Utf8("abc")
2136 EmptyRelation: rows=0
2137 "#
2138 )?;
2139
2140 let expr = Box::new(col("a"));
2141 let pattern = Box::new(lit(ScalarValue::Null));
2142 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2143 let empty = empty_with_type(Utf8);
2144 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2145
2146 assert_analyzed_plan_eq!(
2147 plan,
2148 @r"
2149 Projection: a LIKE CAST(NULL AS Utf8)
2150 EmptyRelation: rows=0
2151 "
2152 )?;
2153
2154 let expr = Box::new(col("a"));
2155 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2156 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2157 let empty = empty_with_type(DataType::Int64);
2158 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2159 assert_type_coercion_error(
2160 plan,
2161 "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
2162 )?;
2163
2164 let expr = Box::new(col("a"));
2166 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2167 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2168 let empty = empty_with_type(Utf8);
2169 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2170
2171 assert_analyzed_plan_eq!(
2172 plan,
2173 @r#"
2174 Projection: a ILIKE Utf8("abc")
2175 EmptyRelation: rows=0
2176 "#
2177 )?;
2178
2179 let expr = Box::new(col("a"));
2180 let pattern = Box::new(lit(ScalarValue::Null));
2181 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2182 let empty = empty_with_type(Utf8);
2183 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2184
2185 assert_analyzed_plan_eq!(
2186 plan,
2187 @r"
2188 Projection: a ILIKE CAST(NULL AS Utf8)
2189 EmptyRelation: rows=0
2190 "
2191 )?;
2192
2193 let expr = Box::new(col("a"));
2194 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2195 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2196 let empty = empty_with_type(DataType::Int64);
2197 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2198 assert_type_coercion_error(
2199 plan,
2200 "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2201 )?;
2202
2203 Ok(())
2204 }
2205
2206 #[test]
2207 fn unknown_for_type_coercion() -> Result<()> {
2208 let expr = col("a").is_unknown();
2210 let empty = empty_with_type(DataType::Boolean);
2211 let plan =
2212 LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2213
2214 assert_analyzed_plan_eq!(
2215 plan,
2216 @r"
2217 Projection: a IS UNKNOWN
2218 EmptyRelation: rows=0
2219 "
2220 )?;
2221
2222 let empty = empty_with_type(Utf8);
2223 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2224 assert_type_coercion_error(
2225 plan,
2226 "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean",
2227 )?;
2228
2229 let expr = col("a").is_not_unknown();
2231 let empty = empty_with_type(DataType::Boolean);
2232 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2233
2234 assert_analyzed_plan_eq!(
2235 plan,
2236 @r"
2237 Projection: a IS NOT UNKNOWN
2238 EmptyRelation: rows=0
2239 "
2240 )
2241 }
2242
2243 #[test]
2244 fn concat_for_type_coercion() -> Result<()> {
2245 let empty = empty_with_type(Utf8);
2246 let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2247
2248 let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2250 signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2251 })
2252 .call(args.to_vec());
2253 let plan =
2254 LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2255 assert_analyzed_plan_eq!(
2256 plan,
2257 @r#"
2258 Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2259 EmptyRelation: rows=0
2260 "#
2261 )
2262 }
2263
2264 #[test]
2265 fn test_type_coercion_rewrite() -> Result<()> {
2266 let schema = Arc::new(DFSchema::from_unqualified_fields(
2268 vec![Field::new("a", DataType::Int64, true)].into(),
2269 std::collections::HashMap::new(),
2270 )?);
2271 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2272 let expr = is_true(lit(12i32).gt(lit(13i64)));
2273 let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2274 let result = expr.rewrite(&mut rewriter).data()?;
2275 assert_eq!(expected, result);
2276
2277 let schema = Arc::new(DFSchema::from_unqualified_fields(
2279 vec![Field::new("a", DataType::Int64, true)].into(),
2280 std::collections::HashMap::new(),
2281 )?);
2282 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2283 let expr = is_true(lit(12i32).eq(lit(13i64)));
2284 let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2285 let result = expr.rewrite(&mut rewriter).data()?;
2286 assert_eq!(expected, result);
2287
2288 let schema = Arc::new(DFSchema::from_unqualified_fields(
2290 vec![Field::new("a", DataType::Int64, true)].into(),
2291 std::collections::HashMap::new(),
2292 )?);
2293 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2294 let expr = is_true(lit(12i32).lt(lit(13i64)));
2295 let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2296 let result = expr.rewrite(&mut rewriter).data()?;
2297 assert_eq!(expected, result);
2298
2299 Ok(())
2300 }
2301
2302 #[test]
2303 fn binary_op_date32_eq_ts() -> Result<()> {
2304 let expr = cast(
2305 lit("1998-03-18"),
2306 DataType::Timestamp(TimeUnit::Nanosecond, None),
2307 )
2308 .eq(cast(lit("1998-03-18"), DataType::Date32));
2309 let empty = empty();
2310 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2311
2312 assert_analyzed_plan_eq!(
2313 plan,
2314 @r#"
2315 Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2316 EmptyRelation: rows=0
2317 "#
2318 )
2319 }
2320
2321 fn cast_if_not_same_type(
2322 expr: Box<Expr>,
2323 data_type: &DataType,
2324 schema: &DFSchemaRef,
2325 ) -> Box<Expr> {
2326 if &expr.get_type(schema).unwrap() != data_type {
2327 Box::new(cast(*expr, data_type.clone()))
2328 } else {
2329 expr
2330 }
2331 }
2332
2333 fn cast_helper(
2334 case: Case,
2335 case_when_type: &DataType,
2336 then_else_type: &DataType,
2337 schema: &DFSchemaRef,
2338 ) -> Case {
2339 let expr = case
2340 .expr
2341 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2342 let when_then_expr = case
2343 .when_then_expr
2344 .into_iter()
2345 .map(|(when, then)| {
2346 (
2347 cast_if_not_same_type(when, case_when_type, schema),
2348 cast_if_not_same_type(then, then_else_type, schema),
2349 )
2350 })
2351 .collect::<Vec<_>>();
2352 let else_expr = case
2353 .else_expr
2354 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2355
2356 Case {
2357 expr,
2358 when_then_expr,
2359 else_expr,
2360 }
2361 }
2362
2363 #[test]
2364 fn test_case_expression_coercion() -> Result<()> {
2365 let schema = Arc::new(DFSchema::from_unqualified_fields(
2366 vec![
2367 Field::new("boolean", DataType::Boolean, true),
2368 Field::new("integer", DataType::Int32, true),
2369 Field::new("float", DataType::Float32, true),
2370 Field::new(
2371 "timestamp",
2372 DataType::Timestamp(TimeUnit::Nanosecond, None),
2373 true,
2374 ),
2375 Field::new("date", DataType::Date32, true),
2376 Field::new(
2377 "interval",
2378 DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2379 true,
2380 ),
2381 Field::new("binary", DataType::Binary, true),
2382 Field::new("string", Utf8, true),
2383 Field::new("decimal", DataType::Decimal128(10, 10), true),
2384 ]
2385 .into(),
2386 std::collections::HashMap::new(),
2387 )?);
2388
2389 let case = Case {
2390 expr: None,
2391 when_then_expr: vec![
2392 (Box::new(col("boolean")), Box::new(col("integer"))),
2393 (Box::new(col("integer")), Box::new(col("float"))),
2394 (Box::new(col("string")), Box::new(col("string"))),
2395 ],
2396 else_expr: None,
2397 };
2398 let case_when_common_type = DataType::Boolean;
2399 let then_else_common_type = Utf8;
2400 let expected = cast_helper(
2401 case.clone(),
2402 &case_when_common_type,
2403 &then_else_common_type,
2404 &schema,
2405 );
2406 let actual = coerce_case_expression(case, &schema)?;
2407 assert_eq!(expected, actual);
2408
2409 let case = Case {
2413 expr: Some(Box::new(col("string"))),
2414 when_then_expr: vec![
2415 (Box::new(col("float")), Box::new(col("integer"))),
2416 (Box::new(col("integer")), Box::new(col("float"))),
2417 (Box::new(col("string")), Box::new(col("string"))),
2418 ],
2419 else_expr: Some(Box::new(col("string"))),
2420 };
2421 let case_when_common_type = DataType::Float32;
2422 let then_else_common_type = Utf8;
2423 let expected = cast_helper(
2424 case.clone(),
2425 &case_when_common_type,
2426 &then_else_common_type,
2427 &schema,
2428 );
2429 let actual = coerce_case_expression(case, &schema)?;
2430 assert_eq!(expected, actual);
2431
2432 let case = Case {
2433 expr: Some(Box::new(col("interval"))),
2434 when_then_expr: vec![
2435 (Box::new(col("float")), Box::new(col("integer"))),
2436 (Box::new(col("binary")), Box::new(col("float"))),
2437 (Box::new(col("string")), Box::new(col("string"))),
2438 ],
2439 else_expr: Some(Box::new(col("string"))),
2440 };
2441 let err = coerce_case_expression(case, &schema).unwrap_err();
2442 assert_snapshot!(
2443 err.strip_backtrace(),
2444 @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2445 );
2446
2447 let case = Case {
2448 expr: Some(Box::new(col("string"))),
2449 when_then_expr: vec![
2450 (Box::new(col("float")), Box::new(col("date"))),
2451 (Box::new(col("string")), Box::new(col("float"))),
2452 (Box::new(col("string")), Box::new(col("binary"))),
2453 ],
2454 else_expr: Some(Box::new(col("timestamp"))),
2455 };
2456 let err = coerce_case_expression(case, &schema).unwrap_err();
2457 assert_snapshot!(
2458 err.strip_backtrace(),
2459 @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2460 );
2461
2462 Ok(())
2463 }
2464
/// Checks type coercion of a `CASE` expression that has no `ELSE` branch.
///
/// Arguments, in order:
/// * `$expr` - an `Option` holding the column name used as the `CASE` operand
///   (`None` for the operand-less form),
/// * `$when_then` - the `(WHEN, THEN)` expression pairs,
/// * `$case_when_type` - type handed to `cast_helper` for the operand/`WHEN` side,
/// * `$then_else_type` - type handed to `cast_helper` for the `THEN` side,
/// * `$schema` - schema the column references are resolved against.
///
/// The macro expands to plain statements that use `?`, so it can only be
/// invoked inside a function returning `Result`.
macro_rules! test_case_expression {
    ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
        let uncoerced = Case {
            expr: $expr.map(|name| Box::new(col(name))),
            when_then_expr: $when_then,
            else_expr: None,
        };

        // The expectation is derived from the same input expression, so only
        // the coercion performed by `coerce_case_expression` is under test.
        let expected = cast_helper(
            uncoerced.clone(),
            &$case_when_type,
            &$then_else_type,
            &$schema,
        );
        let actual = coerce_case_expression(uncoerced, &$schema)?;

        assert_eq!(expected, actual);
    };
}
2480
/// Coercion between the `CASE` operand and its `WHEN` expression when both
/// are list-like columns of different kinds (`List`, `LargeList`,
/// `FixedSizeList`), checked in both operand orders for every pairing.
#[test]
fn tes_case_when_list() -> Result<()> {
    let item = Arc::new(Field::new_list_field(DataType::Int64, true));
    let list_type = DataType::List(Arc::clone(&item));
    let large_list_type = DataType::LargeList(Arc::clone(&item));
    let fixed_list_type = DataType::FixedSizeList(item, 3);

    let fields = vec![
        Field::new("large_list", large_list_type.clone(), true),
        Field::new("fixed_list", fixed_list_type, true),
        Field::new("list", list_type.clone(), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        fields.into(),
        std::collections::HashMap::new(),
    )?);

    // (CASE operand column, WHEN column, expected common type)
    let scenarios = [
        ("list", "large_list", large_list_type.clone()),
        ("large_list", "list", large_list_type.clone()),
        ("list", "fixed_list", list_type.clone()),
        ("fixed_list", "list", list_type),
        ("fixed_list", "large_list", large_list_type.clone()),
        ("large_list", "fixed_list", large_list_type),
    ];

    for (case_column, when_column, common_type) in scenarios {
        test_case_expression!(
            Some(case_column),
            vec![(Box::new(col(when_column)), Box::new(lit("1")))],
            common_type,
            Utf8,
            schema
        );
    }
    Ok(())
}
2551
/// Coercion across the `THEN` branches of an operand-less `CASE` when the
/// branches yield different list-like types (`List`, `LargeList`,
/// `FixedSizeList`), checked in both branch orders for every pairing.
#[test]
fn test_then_else_list() -> Result<()> {
    let item = Arc::new(Field::new_list_field(DataType::Int64, true));
    let list_type = DataType::List(Arc::clone(&item));
    let large_list_type = DataType::LargeList(Arc::clone(&item));
    let fixed_list_type = DataType::FixedSizeList(item, 3);

    let fields = vec![
        Field::new("boolean", DataType::Boolean, true),
        Field::new("large_list", large_list_type.clone(), true),
        Field::new("fixed_list", fixed_list_type, true),
        Field::new("list", list_type.clone(), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        fields.into(),
        std::collections::HashMap::new(),
    )?);

    // (first THEN column, second THEN column, expected common THEN type)
    let scenarios = [
        ("large_list", "list", large_list_type.clone()),
        ("list", "large_list", large_list_type.clone()),
        ("fixed_list", "list", list_type.clone()),
        ("list", "fixed_list", list_type),
        ("fixed_list", "large_list", large_list_type.clone()),
        ("large_list", "fixed_list", large_list_type),
    ];

    for (first_then, second_then, common_type) in scenarios {
        test_case_expression!(
            None::<String>,
            vec![
                (Box::new(col("boolean")), Box::new(col(first_then))),
                (Box::new(col("boolean")), Box::new(col(second_then)))
            ],
            DataType::Boolean,
            common_type,
            schema
        );
    }
    Ok(())
}
2644
/// Compares a map column against the same column cast to a map type that
/// differs only in the name of its entries field (`key_value` instead of
/// `entries`). The snapshot expects the explicitly cast side to be cast
/// back to the column's own map type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    let mut entry_builder = SchemaBuilder::new();
    entry_builder.push(Field::new("key", Utf8, false));
    entry_builder.push(Field::new("value", DataType::Float64, true));
    let entry_fields = entry_builder.finish().fields;

    // Both map types share the same key/value struct; only the name of the
    // wrapping entries field differs.
    let map_with_entries_named = |entries_name: &str| {
        let entries =
            Field::new(entries_name, DataType::Struct(entry_fields.clone()), false);
        DataType::Map(Arc::new(entries), false)
    };
    let column_map_type = map_with_entries_named("entries");
    let custom_map_type = map_with_entries_named("key_value");

    let comparison = col("a").eq(cast(col("a"), custom_map_type));
    let input = empty_with_type(column_map_type);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![comparison], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a = CAST(CAST(a AS Map("key_value": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) AS Map("entries": non-null Struct("key": non-null Utf8, "value": Float64), unsorted))
      EmptyRelation: rows=0
    "#
    )
}
2671
/// Adds a year-month interval literal (left operand) to a string literal
/// that is explicitly cast to a nanosecond timestamp (right operand). The
/// snapshot expects the expression to come out of analysis unchanged.
#[test]
fn interval_plus_timestamp() -> Result<()> {
    let interval = lit(ScalarValue::IntervalYearMonth(Some(12)));
    let timestamp = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(interval),
        Operator::Plus,
        Box::new(timestamp),
    ));

    let input = empty();
    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2694
/// Subtracts two identical string literals that are each explicitly cast to
/// a nanosecond timestamp. The snapshot expects no additional casts to be
/// inserted around either operand.
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both sides of the subtraction are the same expression, so build it
    // once and clone it for the left-hand side.
    let timestamp = cast(
        lit("1998-03-18"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let difference = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(timestamp.clone()),
        Operator::Minus,
        Box::new(timestamp),
    ));

    let input = empty();
    let plan = LogicalPlan::Projection(Projection::try_new(vec![difference], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2719
/// `a IN (<subquery>)` where the filter input is built with `Int64` and the
/// subquery input with `Int32`. The snapshot expects the cast to be added
/// on the subquery side (as a projection) while the outer `a` stays as is.
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Int32),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let filter_input = empty_with_type(DataType::Int64);

    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, filter_input)?);

    // Only the subquery needs widening.
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: a IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Int64)
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2748
/// `a IN (<subquery>)` where the filter input is built with `Int32` and the
/// subquery input with `Int64`. The snapshot expects the outer `a` to be
/// cast to `Int64` while the subquery plan is left without a projection.
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    let filter_input = empty_with_type(DataType::Int32);
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Int64),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };

    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, filter_input)?);

    // Only the outer expression needs widening.
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Int64) IN (<subquery>)
      Subquery:
        EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2776
/// `a IN (<subquery>)` where the subquery input is `Decimal128(10, 5)` and
/// the filter input is `Decimal128(8, 8)`. Neither type matches the type in
/// the snapshot, `Decimal128(13, 8)`, so a cast is expected on both the
/// outer expression and the subquery.
#[test]
fn in_subquery_cast_all() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Decimal128(10, 5)),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let filter_input = empty_with_type(DataType::Decimal128(8, 8));

    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, filter_input)?);

    // Both sides are cast to the common decimal type.
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Decimal128(13, 8))
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2805}