1use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::izip;
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32 exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
33 DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34};
35use datafusion_expr::expr::{
36 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
37 InSubquery, Like, ScalarFunction, Sort, WindowFunction,
38};
39use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
40use datafusion_expr::expr_schema::cast_subquery;
41use datafusion_expr::logical_plan::Subquery;
42use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
43use datafusion_expr::type_coercion::functions::{
44 data_types_with_aggregate_udf, data_types_with_scalar_udf,
45};
46use datafusion_expr::type_coercion::other::{
47 get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52 is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
53 AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan,
54 Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound,
55 WindowFrameUnits,
56};
57
/// Analyzer rule that performs type coercion on a [`LogicalPlan`].
///
/// The rule is stateless; it is registered under the name `type_coercion`
/// through its [`AnalyzerRule`] implementation.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self::default()
    }
}
68
69fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
71 if !config.optimizer.expand_views_at_output {
72 return Ok(plan);
73 }
74
75 let outer_refs = plan.expressions();
76 if outer_refs.is_empty() {
77 return Ok(plan);
78 }
79
80 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
81 coerce_plan_expr_for_schema(plan, &dfschema?)
82 } else {
83 Ok(plan)
84 }
85}
86
87impl AnalyzerRule for TypeCoercion {
88 fn name(&self) -> &str {
89 "type_coercion"
90 }
91
92 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
93 let empty_schema = DFSchema::empty();
94
95 let transformed_plan = plan
97 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
98 .data;
99
100 coerce_output(transformed_plan, config)
102 }
103}
104
105fn analyze_internal(
109 external_schema: &DFSchema,
110 plan: LogicalPlan,
111) -> Result<Transformed<LogicalPlan>> {
112 let mut schema = merge_schema(&plan.inputs());
115
116 if let LogicalPlan::TableScan(ts) = &plan {
117 let source_schema = DFSchema::try_from_qualified_schema(
118 ts.table_name.clone(),
119 &ts.source.schema(),
120 )?;
121 schema.merge(&source_schema);
122 }
123
124 schema.merge(external_schema);
128
129 let plan = if let LogicalPlan::Filter(mut filter) = plan {
131 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
132 LogicalPlan::Filter(filter)
133 } else {
134 plan
135 };
136
137 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
138
139 let name_preserver = NamePreserver::new(&plan);
140 plan.map_expressions(|expr| {
142 let original_name = name_preserver.save(&expr);
143 expr.rewrite(&mut expr_rewrite)
144 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
145 })?
146 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
148 .map_data(|plan| plan.recompute_schema())
150}
151
152pub struct TypeCoercionRewriter<'a> {
154 pub(crate) schema: &'a DFSchema,
155}
156
157impl<'a> TypeCoercionRewriter<'a> {
158 pub fn new(schema: &'a DFSchema) -> Self {
161 Self { schema }
162 }
163
164 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
169 match plan {
170 LogicalPlan::Join(join) => self.coerce_join(join),
171 LogicalPlan::Union(union) => Self::coerce_union(union),
172 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
173 _ => Ok(plan),
174 }
175 }
176
177 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
186 join.on = join
187 .on
188 .into_iter()
189 .map(|(lhs, rhs)| {
190 let left_schema = join.left.schema();
193 let right_schema = join.right.schema();
194 let (lhs, rhs) = self.coerce_binary_op(
195 lhs,
196 left_schema,
197 Operator::Eq,
198 rhs,
199 right_schema,
200 )?;
201 Ok((lhs, rhs))
202 })
203 .collect::<Result<Vec<_>>>()?;
204
205 join.filter = join
207 .filter
208 .map(|expr| self.coerce_join_filter(expr))
209 .transpose()?;
210
211 Ok(LogicalPlan::Join(join))
212 }
213
214 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
217 let union_schema = Arc::new(coerce_union_schema_with_schema(
218 &union_plan.inputs,
219 &union_plan.schema,
220 )?);
221 let new_inputs = union_plan
222 .inputs
223 .into_iter()
224 .map(|p| {
225 let plan =
226 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
227 match plan {
228 LogicalPlan::Projection(Projection { expr, input, .. }) => {
229 Ok(Arc::new(project_with_column_index(
230 expr,
231 input,
232 Arc::clone(&union_schema),
233 )?))
234 }
235 other_plan => Ok(Arc::new(other_plan)),
236 }
237 })
238 .collect::<Result<Vec<_>>>()?;
239 Ok(LogicalPlan::Union(Union {
240 inputs: new_inputs,
241 schema: union_schema,
242 }))
243 }
244
245 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
247 fn coerce_limit_expr(
248 expr: Expr,
249 schema: &DFSchema,
250 expr_name: &str,
251 ) -> Result<Expr> {
252 let dt = expr.get_type(schema)?;
253 if dt.is_integer() || dt.is_null() {
254 expr.cast_to(&DataType::Int64, schema)
255 } else {
256 plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}")
257 }
258 }
259
260 let empty_schema = DFSchema::empty();
261 let new_fetch = limit
262 .fetch
263 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
264 .transpose()?;
265 let new_skip = limit
266 .skip
267 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
268 .transpose()?;
269 Ok(LogicalPlan::Limit(Limit {
270 input: limit.input,
271 fetch: new_fetch.map(Box::new),
272 skip: new_skip.map(Box::new),
273 }))
274 }
275
276 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
277 let expr_type = expr.get_type(self.schema)?;
278 match expr_type {
279 DataType::Boolean => Ok(expr),
280 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
281 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
282 }
283 }
284
285 fn coerce_binary_op(
286 &self,
287 left: Expr,
288 left_schema: &DFSchema,
289 op: Operator,
290 right: Expr,
291 right_schema: &DFSchema,
292 ) -> Result<(Expr, Expr)> {
293 let (left_type, right_type) = BinaryTypeCoercer::new(
294 &left.get_type(left_schema)?,
295 &op,
296 &right.get_type(right_schema)?,
297 )
298 .get_input_types()?;
299
300 Ok((
301 left.cast_to(&left_type, left_schema)?,
302 right.cast_to(&right_type, right_schema)?,
303 ))
304 }
305}
306
307impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
308 type Node = Expr;
309
310 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
311 match expr {
312 Expr::Unnest(_) => not_impl_err!(
313 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
314 ),
315 Expr::ScalarSubquery(Subquery {
316 subquery,
317 outer_ref_columns,
318 spans,
319 }) => {
320 let new_plan =
321 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
322 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
323 subquery: Arc::new(new_plan),
324 outer_ref_columns,
325 spans,
326 })))
327 }
328 Expr::Exists(Exists { subquery, negated }) => {
329 let new_plan = analyze_internal(
330 self.schema,
331 Arc::unwrap_or_clone(subquery.subquery),
332 )?
333 .data;
334 Ok(Transformed::yes(Expr::Exists(Exists {
335 subquery: Subquery {
336 subquery: Arc::new(new_plan),
337 outer_ref_columns: subquery.outer_ref_columns,
338 spans: subquery.spans,
339 },
340 negated,
341 })))
342 }
343 Expr::InSubquery(InSubquery {
344 expr,
345 subquery,
346 negated,
347 }) => {
348 let new_plan = analyze_internal(
349 self.schema,
350 Arc::unwrap_or_clone(subquery.subquery),
351 )?
352 .data;
353 let expr_type = expr.get_type(self.schema)?;
354 let subquery_type = new_plan.schema().field(0).data_type();
355 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
356 "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
357 ),
358 )?;
359 let new_subquery = Subquery {
360 subquery: Arc::new(new_plan),
361 outer_ref_columns: subquery.outer_ref_columns,
362 spans: subquery.spans,
363 };
364 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
365 Box::new(expr.cast_to(&common_type, self.schema)?),
366 cast_subquery(new_subquery, &common_type)?,
367 negated,
368 ))))
369 }
370 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
371 *expr,
372 self.schema,
373 )?))),
374 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
375 get_casted_expr_for_bool_op(*expr, self.schema)?,
376 ))),
377 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
378 get_casted_expr_for_bool_op(*expr, self.schema)?,
379 ))),
380 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
381 get_casted_expr_for_bool_op(*expr, self.schema)?,
382 ))),
383 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
384 get_casted_expr_for_bool_op(*expr, self.schema)?,
385 ))),
386 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
387 get_casted_expr_for_bool_op(*expr, self.schema)?,
388 ))),
389 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
390 get_casted_expr_for_bool_op(*expr, self.schema)?,
391 ))),
392 Expr::Like(Like {
393 negated,
394 expr,
395 pattern,
396 escape_char,
397 case_insensitive,
398 }) => {
399 let left_type = expr.get_type(self.schema)?;
400 let right_type = pattern.get_type(self.schema)?;
401 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
402 let op_name = if case_insensitive {
403 "ILIKE"
404 } else {
405 "LIKE"
406 };
407 plan_datafusion_err!(
408 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
409 )
410 })?;
411 let expr = match left_type {
412 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
413 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
414 };
415 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
416 Ok(Transformed::yes(Expr::Like(Like::new(
417 negated,
418 expr,
419 pattern,
420 escape_char,
421 case_insensitive,
422 ))))
423 }
424 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
425 let (left, right) =
426 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
427 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
428 Box::new(left),
429 op,
430 Box::new(right),
431 ))))
432 }
433 Expr::Between(Between {
434 expr,
435 negated,
436 low,
437 high,
438 }) => {
439 let expr_type = expr.get_type(self.schema)?;
440 let low_type = low.get_type(self.schema)?;
441 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
442 .ok_or_else(|| {
443 DataFusionError::Internal(format!(
444 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
445 ))
446 })?;
447 let high_type = high.get_type(self.schema)?;
448 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
449 .ok_or_else(|| {
450 DataFusionError::Internal(format!(
451 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
452 ))
453 })?;
454 let coercion_type =
455 comparison_coercion(&low_coerced_type, &high_coerced_type)
456 .ok_or_else(|| {
457 DataFusionError::Internal(format!(
458 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
459 ))
460 })?;
461 Ok(Transformed::yes(Expr::Between(Between::new(
462 Box::new(expr.cast_to(&coercion_type, self.schema)?),
463 negated,
464 Box::new(low.cast_to(&coercion_type, self.schema)?),
465 Box::new(high.cast_to(&coercion_type, self.schema)?),
466 ))))
467 }
468 Expr::InList(InList {
469 expr,
470 list,
471 negated,
472 }) => {
473 let expr_data_type = expr.get_type(self.schema)?;
474 let list_data_types = list
475 .iter()
476 .map(|list_expr| list_expr.get_type(self.schema))
477 .collect::<Result<Vec<_>>>()?;
478 let result_type =
479 get_coerce_type_for_list(&expr_data_type, &list_data_types);
480 match result_type {
481 None => plan_err!(
482 "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
483 ),
484 Some(coerced_type) => {
485 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
487 let cast_list_expr = list
488 .into_iter()
489 .map(|list_expr| {
490 list_expr.cast_to(&coerced_type, self.schema)
491 })
492 .collect::<Result<Vec<_>>>()?;
493 Ok(Transformed::yes(Expr::InList(InList ::new(
494 Box::new(cast_expr),
495 cast_list_expr,
496 negated,
497 ))))
498 }
499 }
500 }
501 Expr::Case(case) => {
502 let case = coerce_case_expression(case, self.schema)?;
503 Ok(Transformed::yes(Expr::Case(case)))
504 }
505 Expr::ScalarFunction(ScalarFunction { func, args }) => {
506 let new_expr = coerce_arguments_for_signature_with_scalar_udf(
507 args,
508 self.schema,
509 &func,
510 )?;
511 Ok(Transformed::yes(Expr::ScalarFunction(
512 ScalarFunction::new_udf(func, new_expr),
513 )))
514 }
515 Expr::AggregateFunction(expr::AggregateFunction {
516 func,
517 params:
518 AggregateFunctionParams {
519 args,
520 distinct,
521 filter,
522 order_by,
523 null_treatment,
524 },
525 }) => {
526 let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
527 args,
528 self.schema,
529 &func,
530 )?;
531 Ok(Transformed::yes(Expr::AggregateFunction(
532 expr::AggregateFunction::new_udf(
533 func,
534 new_expr,
535 distinct,
536 filter,
537 order_by,
538 null_treatment,
539 ),
540 )))
541 }
542 Expr::WindowFunction(WindowFunction {
543 fun,
544 params:
545 expr::WindowFunctionParams {
546 args,
547 partition_by,
548 order_by,
549 window_frame,
550 null_treatment,
551 },
552 }) => {
553 let window_frame =
554 coerce_window_frame(window_frame, self.schema, &order_by)?;
555
556 let args = match &fun {
557 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
558 coerce_arguments_for_signature_with_aggregate_udf(
559 args,
560 self.schema,
561 udf,
562 )?
563 }
564 _ => args,
565 };
566
567 Ok(Transformed::yes(
568 Expr::WindowFunction(WindowFunction::new(fun, args))
569 .partition_by(partition_by)
570 .order_by(order_by)
571 .window_frame(window_frame)
572 .null_treatment(null_treatment)
573 .build()?,
574 ))
575 }
576 #[expect(deprecated)]
578 Expr::Alias(_)
579 | Expr::Column(_)
580 | Expr::ScalarVariable(_, _)
581 | Expr::Literal(_)
582 | Expr::SimilarTo(_)
583 | Expr::IsNotNull(_)
584 | Expr::IsNull(_)
585 | Expr::Negative(_)
586 | Expr::Cast(_)
587 | Expr::TryCast(_)
588 | Expr::Wildcard { .. }
589 | Expr::GroupingSet(_)
590 | Expr::Placeholder(_)
591 | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
592 }
593 }
594}
595
596fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
598 let metadata = dfschema.as_arrow().metadata.clone();
599 let mut transformed = false;
600
601 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
602 dfschema
603 .iter()
604 .map(|(qualifier, field)| match field.data_type() {
605 DataType::Utf8View => {
606 transformed = true;
607 (
608 qualifier.cloned() as Option<TableReference>,
609 Arc::new(Field::new(
610 field.name(),
611 DataType::LargeUtf8,
612 field.is_nullable(),
613 )),
614 )
615 }
616 DataType::BinaryView => {
617 transformed = true;
618 (
619 qualifier.cloned() as Option<TableReference>,
620 Arc::new(Field::new(
621 field.name(),
622 DataType::LargeBinary,
623 field.is_nullable(),
624 )),
625 )
626 }
627 _ => (
628 qualifier.cloned() as Option<TableReference>,
629 Arc::clone(field),
630 ),
631 })
632 .unzip();
633
634 if !transformed {
635 return None;
636 }
637
638 let schema = Schema::new_with_metadata(transformed_fields, metadata);
639 Some(DFSchema::from_field_specific_qualified_schema(
640 qualifiers,
641 &Arc::new(schema),
642 ))
643}
644
645fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
648 match value {
649 ScalarValue::Utf8(Some(val)) => {
651 ScalarValue::try_from_string(val.clone(), target_type)
652 }
653 s => {
654 if s.is_null() {
655 ScalarValue::try_from(target_type)
657 } else {
658 Ok(s.clone())
662 }
663 }
664 }
665}
666
667fn coerce_scalar_range_aware(
674 target_type: &DataType,
675 value: &ScalarValue,
676) -> Result<ScalarValue> {
677 coerce_scalar(target_type, value).or_else(|err| {
678 if let Some(largest_type) = get_widest_type_in_family(target_type) {
680 coerce_scalar(largest_type, value).map_or_else(
681 |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
682 |_| ScalarValue::try_from(target_type),
683 )
684 } else {
685 Err(err)
686 }
687 })
688}
689
690fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
694 match given_type {
695 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
696 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
697 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
698 _ => None,
699 }
700}
701
702fn coerce_frame_bound(
704 target_type: &DataType,
705 bound: WindowFrameBound,
706) -> Result<WindowFrameBound> {
707 match bound {
708 WindowFrameBound::Preceding(v) => {
709 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
710 }
711 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
712 WindowFrameBound::Following(v) => {
713 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
714 }
715 }
716}
717
718fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
719 if col_type.is_numeric()
720 || is_utf8_or_utf8view_or_large_utf8(col_type)
721 || matches!(col_type, DataType::Null)
722 || matches!(col_type, DataType::Boolean)
723 {
724 Ok(col_type.clone())
725 } else if is_datetime(col_type) {
726 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
727 } else if let DataType::Dictionary(_, value_type) = col_type {
728 extract_window_frame_target_type(value_type)
729 } else {
730 return internal_err!("Cannot run range queries on datatype: {col_type:?}");
731 }
732}
733
734fn coerce_window_frame(
737 window_frame: WindowFrame,
738 schema: &DFSchema,
739 expressions: &[Sort],
740) -> Result<WindowFrame> {
741 let mut window_frame = window_frame;
742 let target_type = match window_frame.units {
743 WindowFrameUnits::Range => {
744 let current_types = expressions
745 .first()
746 .map(|s| s.expr.get_type(schema))
747 .transpose()?;
748 if let Some(col_type) = current_types {
749 extract_window_frame_target_type(&col_type)?
750 } else {
751 return internal_err!("ORDER BY column cannot be empty");
752 }
753 }
754 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
755 };
756 window_frame.start_bound =
757 coerce_frame_bound(&target_type, window_frame.start_bound)?;
758 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
759 Ok(window_frame)
760}
761
762fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
765 let left_type = expr.get_type(schema)?;
766 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
767 .get_input_types()?;
768 expr.cast_to(&DataType::Boolean, schema)
769}
770
771fn coerce_arguments_for_signature_with_scalar_udf(
776 expressions: Vec<Expr>,
777 schema: &DFSchema,
778 func: &ScalarUDF,
779) -> Result<Vec<Expr>> {
780 if expressions.is_empty() {
781 return Ok(expressions);
782 }
783
784 let current_types = expressions
785 .iter()
786 .map(|e| e.get_type(schema))
787 .collect::<Result<Vec<_>>>()?;
788
789 let new_types = data_types_with_scalar_udf(¤t_types, func)?;
790
791 expressions
792 .into_iter()
793 .enumerate()
794 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
795 .collect()
796}
797
798fn coerce_arguments_for_signature_with_aggregate_udf(
803 expressions: Vec<Expr>,
804 schema: &DFSchema,
805 func: &AggregateUDF,
806) -> Result<Vec<Expr>> {
807 if expressions.is_empty() {
808 return Ok(expressions);
809 }
810
811 let current_types = expressions
812 .iter()
813 .map(|e| e.get_type(schema))
814 .collect::<Result<Vec<_>>>()?;
815
816 let new_types = data_types_with_aggregate_udf(¤t_types, func)?;
817
818 expressions
819 .into_iter()
820 .enumerate()
821 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
822 .collect()
823}
824
825fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
826 let case_type = case
858 .expr
859 .as_ref()
860 .map(|expr| expr.get_type(schema))
861 .transpose()?;
862 let then_types = case
863 .when_then_expr
864 .iter()
865 .map(|(_when, then)| then.get_type(schema))
866 .collect::<Result<Vec<_>>>()?;
867 let else_type = case
868 .else_expr
869 .as_ref()
870 .map(|expr| expr.get_type(schema))
871 .transpose()?;
872
873 let case_when_coerce_type = case_type
875 .as_ref()
876 .map(|case_type| {
877 let when_types = case
878 .when_then_expr
879 .iter()
880 .map(|(when, _then)| when.get_type(schema))
881 .collect::<Result<Vec<_>>>()?;
882 let coerced_type =
883 get_coerce_type_for_case_expression(&when_types, Some(case_type));
884 coerced_type.ok_or_else(|| {
885 plan_datafusion_err!(
886 "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \
887 to common types in CASE WHEN expression"
888 )
889 })
890 })
891 .transpose()?;
892 let then_else_coerce_type =
893 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
894 || {
895 plan_datafusion_err!(
896 "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \
897 to common types in CASE WHEN expression"
898 )
899 },
900 )?;
901
902 let case_expr = case
904 .expr
905 .zip(case_when_coerce_type.as_ref())
906 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
907 .transpose()?
908 .map(Box::new);
909 let when_then = case
910 .when_then_expr
911 .into_iter()
912 .map(|(when, then)| {
913 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
914 let when = when.cast_to(when_type, schema).map_err(|e| {
915 DataFusionError::Context(
916 format!(
917 "WHEN expressions in CASE couldn't be \
918 converted to common type ({when_type})"
919 ),
920 Box::new(e),
921 )
922 })?;
923 let then = then.cast_to(&then_else_coerce_type, schema)?;
924 Ok((Box::new(when), Box::new(then)))
925 })
926 .collect::<Result<Vec<_>>>()?;
927 let else_expr = case
928 .else_expr
929 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
930 .transpose()?
931 .map(Box::new);
932
933 Ok(Case::new(case_expr, when_then, else_expr))
934}
935
936pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
941 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
942}
943fn coerce_union_schema_with_schema(
944 inputs: &[Arc<LogicalPlan>],
945 base_schema: &DFSchemaRef,
946) -> Result<DFSchema> {
947 let mut union_datatypes = base_schema
948 .fields()
949 .iter()
950 .map(|f| f.data_type().clone())
951 .collect::<Vec<_>>();
952 let mut union_nullabilities = base_schema
953 .fields()
954 .iter()
955 .map(|f| f.is_nullable())
956 .collect::<Vec<_>>();
957 let mut union_field_meta = base_schema
958 .fields()
959 .iter()
960 .map(|f| f.metadata().clone())
961 .collect::<Vec<_>>();
962
963 let mut metadata = base_schema.metadata().clone();
964
965 for (i, plan) in inputs.iter().enumerate() {
966 let plan_schema = plan.schema();
967 metadata.extend(plan_schema.metadata().clone());
968
969 if plan_schema.fields().len() != base_schema.fields().len() {
970 return plan_err!(
971 "Union schemas have different number of fields: \
972 query 1 has {} fields whereas query {} has {} fields",
973 base_schema.fields().len(),
974 i + 1,
975 plan_schema.fields().len()
976 );
977 }
978
979 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
981 union_datatypes.iter_mut(),
982 union_nullabilities.iter_mut(),
983 union_field_meta.iter_mut(),
984 plan_schema.fields().iter()
985 ) {
986 let coerced_type =
987 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
988 || {
989 plan_datafusion_err!(
990 "Incompatible inputs for Union: Previous inputs were \
991 of type {}, but got incompatible type {} on column '{}'",
992 union_datatype,
993 plan_field.data_type(),
994 plan_field.name()
995 )
996 },
997 )?;
998
999 *union_datatype = coerced_type;
1000 *union_nullable = *union_nullable || plan_field.is_nullable();
1001 union_field_map.extend(plan_field.metadata().clone());
1002 }
1003 }
1004 let union_qualified_fields = izip!(
1005 base_schema.fields(),
1006 union_datatypes.into_iter(),
1007 union_nullabilities,
1008 union_field_meta.into_iter()
1009 )
1010 .map(|(field, datatype, nullable, metadata)| {
1011 let mut field = Field::new(field.name().clone(), datatype, nullable);
1012 field.set_metadata(metadata);
1013 (None, field.into())
1014 })
1015 .collect::<Vec<_>>();
1016
1017 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1018}
1019
1020fn project_with_column_index(
1022 expr: Vec<Expr>,
1023 input: Arc<LogicalPlan>,
1024 schema: DFSchemaRef,
1025) -> Result<LogicalPlan> {
1026 let alias_expr = expr
1027 .into_iter()
1028 .enumerate()
1029 .map(|(i, e)| match e {
1030 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1031 Ok(e.unalias().alias(schema.field(i).name()))
1032 }
1033 Expr::Column(Column {
1034 relation: _,
1035 ref name,
1036 spans: _,
1037 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1038 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1039 #[expect(deprecated)]
1040 Expr::Wildcard { .. } => {
1041 plan_err!("Wildcard should be expanded before type coercion")
1042 }
1043 _ => Ok(e.alias(schema.field(i).name())),
1044 })
1045 .collect::<Result<Vec<_>>>()?;
1046
1047 Projection::try_new_with_schema(alias_expr, input, schema)
1048 .map(LogicalPlan::Projection)
1049}
1050
1051#[cfg(test)]
1052mod test {
1053 use std::any::Any;
1054 use std::sync::Arc;
1055
1056 use arrow::datatypes::DataType::Utf8;
1057 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1058
1059 use crate::analyzer::type_coercion::{
1060 coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1061 };
1062 use crate::analyzer::Analyzer;
1063 use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq};
1064 use datafusion_common::config::ConfigOptions;
1065 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1066 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1067 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1068 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1069 use datafusion_expr::test::function_stub::avg_udaf;
1070 use datafusion_expr::{
1071 cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1072 BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1073 Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1074 SimpleAggregateUDF, Subquery, Union, Volatility,
1075 };
1076 use datafusion_functions_aggregate::average::AvgAccumulator;
1077 use datafusion_sql::TableReference;
1078
1079 fn empty() -> Arc<LogicalPlan> {
1080 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1081 produce_one_row: false,
1082 schema: Arc::new(DFSchema::empty()),
1083 }))
1084 }
1085
1086 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1087 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1088 produce_one_row: false,
1089 schema: Arc::new(
1090 DFSchema::from_unqualified_fields(
1091 vec![Field::new("a", data_type, true)].into(),
1092 std::collections::HashMap::new(),
1093 )
1094 .unwrap(),
1095 ),
1096 }))
1097 }
1098
1099 #[test]
1100 fn simple_case() -> Result<()> {
1101 let expr = col("a").lt(lit(2_u32));
1102 let empty = empty_with_type(DataType::Float64);
1103 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1104 let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation";
1105 assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1106 }
1107
1108 #[test]
1109 fn test_coerce_union() -> Result<()> {
1110 let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1111 produce_one_row: false,
1112 schema: Arc::new(
1113 DFSchema::try_from_qualified_schema(
1114 TableReference::full("datafusion", "test", "foo"),
1115 &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1116 )
1117 .unwrap(),
1118 ),
1119 }));
1120 let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1121 produce_one_row: false,
1122 schema: Arc::new(
1123 DFSchema::try_from_qualified_schema(
1124 TableReference::full("datafusion", "test", "foo"),
1125 &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1126 )
1127 .unwrap(),
1128 ),
1129 }));
1130 let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1131 left_plan, right_plan,
1132 ])?);
1133 let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1134 .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1135 let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1136 vec![col("a")],
1137 Arc::new(analyzed_union),
1138 )?);
1139
1140 let expected = "Projection: a\n Union\n Projection: CAST(datafusion.test.foo.a AS Int64) AS a\n EmptyRelation\n EmptyRelation";
1141 assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), top_level_plan, expected)
1142 }
1143
1144 fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> {
1145 let mut options = ConfigOptions::default();
1146 options.optimizer.expand_views_at_output = true;
1147
1148 assert_analyzed_plan_with_config_eq(
1149 options,
1150 Arc::new(TypeCoercion::new()),
1151 plan.clone(),
1152 expected,
1153 )
1154 }
1155
1156 fn do_not_coerce_on_output(plan: LogicalPlan, expected: &str) -> Result<()> {
1157 assert_analyzed_plan_with_config_eq(
1158 ConfigOptions::default(),
1159 Arc::new(TypeCoercion::new()),
1160 plan.clone(),
1161 expected,
1162 )
1163 }
1164
/// Checks output handling of a `Utf8View` column: through
/// `do_not_coerce_on_output` each plan is expected to stay as written, while
/// through `coerce_on_output_if_viewtype` a `CAST(a AS LargeUtf8)` projection
/// is expected at the top whenever the column itself reaches the output.
#[test]
fn coerce_utf8view_output() -> Result<()> {
    let column = col("a");
    let view_input = empty_with_type(DataType::Utf8View);

    // Plain projection of the view-typed column.
    let projection = LogicalPlan::Projection(Projection::try_new(
        vec![column.clone()],
        Arc::clone(&view_input),
    )?);
    do_not_coerce_on_output(projection.clone(), "Projection: a\n EmptyRelation")?;
    coerce_on_output_if_viewtype(
        projection.clone(),
        "Projection: CAST(a AS LargeUtf8)\n EmptyRelation",
    )?;

    // Comparison over the column: both helpers expect the same plan text.
    let comparison = col("a").lt(lit("foo"));
    let comparison_plan = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        Arc::clone(&view_input),
    )?);
    let unchanged = "Projection: a < CAST(Utf8(\"foo\") AS Utf8View)\n EmptyRelation";
    do_not_coerce_on_output(comparison_plan.clone(), unchanged)?;
    coerce_on_output_if_viewtype(comparison_plan, unchanged)?;

    // Sort stacked on the first projection.
    let sorted = LogicalPlan::Sort(Sort {
        expr: vec![column.sort(true, true)],
        input: Arc::new(projection),
        fetch: None,
    });
    do_not_coerce_on_output(
        sorted.clone(),
        "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )?;
    coerce_on_output_if_viewtype(
        sorted.clone(),
        "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )?;

    // Another projection stacked on the sort.
    let outer = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sorted),
    )?);
    do_not_coerce_on_output(
        outer.clone(),
        "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )?;
    coerce_on_output_if_viewtype(
        outer,
        "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )
}
1228
/// Checks output handling of a `BinaryView` column: through
/// `do_not_coerce_on_output` each plan is expected to stay as written, while
/// through `coerce_on_output_if_viewtype` a `CAST(a AS LargeBinary)`
/// projection is expected at the top whenever the column itself reaches the
/// output.
#[test]
fn coerce_binaryview_output() -> Result<()> {
    let column = col("a");
    let view_input = empty_with_type(DataType::BinaryView);

    // Plain projection of the view-typed column.
    let projection = LogicalPlan::Projection(Projection::try_new(
        vec![column.clone()],
        Arc::clone(&view_input),
    )?);
    do_not_coerce_on_output(projection.clone(), "Projection: a\n EmptyRelation")?;
    coerce_on_output_if_viewtype(
        projection.clone(),
        "Projection: CAST(a AS LargeBinary)\n EmptyRelation",
    )?;

    // Comparison over the column: both helpers expect the same plan text.
    let comparison = col("a").lt(lit(vec![8, 1, 8, 1]));
    let comparison_plan = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        Arc::clone(&view_input),
    )?);
    let unchanged =
        "Projection: a < CAST(Binary(\"8,1,8,1\") AS BinaryView)\n EmptyRelation";
    do_not_coerce_on_output(comparison_plan.clone(), unchanged)?;
    coerce_on_output_if_viewtype(comparison_plan, unchanged)?;

    // Sort stacked on the first projection.
    let sorted = LogicalPlan::Sort(Sort {
        expr: vec![column.sort(true, true)],
        input: Arc::new(projection),
        fetch: None,
    });
    do_not_coerce_on_output(
        sorted.clone(),
        "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )?;
    coerce_on_output_if_viewtype(
        sorted.clone(),
        "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )?;

    // Another projection stacked on the sort.
    let outer = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sorted),
    )?);
    do_not_coerce_on_output(
        outer.clone(),
        "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )?;
    coerce_on_output_if_viewtype(
        outer,
        "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation",
    )
}
1292
/// Both sides of an `OR` of two identical `a < 2_u32` comparisons over a
/// `Float64` column are expected to get the literal cast to `Float64`.
#[test]
fn nested_case() -> Result<()> {
    let comparison = col("a").lt(lit(2_u32));
    let input = empty_with_type(DataType::Float64);

    let disjunction = comparison.clone().or(comparison);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![disjunction], input)?);

    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)\n EmptyRelation",
    )
}
1306
1307 #[derive(Debug, Clone)]
1308 struct TestScalarUDF {
1309 signature: Signature,
1310 }
1311
1312 impl ScalarUDFImpl for TestScalarUDF {
1313 fn as_any(&self) -> &dyn Any {
1314 self
1315 }
1316
1317 fn name(&self) -> &str {
1318 "TestScalarUDF"
1319 }
1320
1321 fn signature(&self) -> &Signature {
1322 &self.signature
1323 }
1324
1325 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1326 Ok(Utf8)
1327 }
1328
1329 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1330 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1331 }
1332 }
1333
/// An `Int32` literal passed to a UDF with a uniform `Float32` signature is
/// expected to be wrapped in a cast to `Float32`.
#[test]
fn scalar_udf() -> Result<()> {
    let input = empty();

    let signature = Signature::uniform(1, vec![DataType::Float32], Volatility::Stable);
    let call = ScalarUDF::from(TestScalarUDF { signature }).call(vec![lit(123_i32)]);

    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], input)?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation",
    )
}
1347
/// Building a projection that calls a `Float32`-only UDF with a string
/// literal is expected to fail.
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
    let input = empty();

    let signature = Signature::uniform(1, vec![DataType::Float32], Volatility::Stable);
    let call = ScalarUDF::from(TestScalarUDF { signature }).call(vec![lit("Apple")]);

    let outcome = Projection::try_new(vec![call], input);
    outcome.expect_err("Expected an error due to incorrect function input");

    Ok(())
}
1360
/// Same check as the UDF call test, but the call expression is assembled by
/// hand through `ScalarFunction::new_udf` with an `Int64` literal argument.
#[test]
fn scalar_function() -> Result<()> {
    let input = empty();

    let udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let argument = lit(10i64);
    let call = ScalarFunction::new_udf(Arc::new(udf), vec![argument]);

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![Expr::ScalarFunction(call)],
        input,
    )?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation",
    )
}
1379
/// An `Int64` literal fed to a user-defined aggregate declared over
/// `Float64` is expected to be cast to `Float64`.
#[test]
fn agg_udaf() -> Result<()> {
    let input = empty();

    let my_avg = create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    let call = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit(10i64)],
        false,
        None,
        None,
        None,
    );

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![Expr::AggregateFunction(call)],
        input,
    )?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation",
    )
}
1403
/// Calling a `Float64`-only user-defined aggregate with a string literal is
/// expected to fail at projection construction with a coercion error.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let input = empty();

    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let state_fields = vec![
        Field::new("count", DataType::UInt64, true),
        Field::new("avg", DataType::Float64, true),
    ];
    let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
        "MY_AVG",
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        DataType::Float64,
        accumulator,
        state_fields,
    ));
    let call = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit("10")],
        false,
        None,
        None,
        None,
    );

    let failure = Projection::try_new(vec![Expr::AggregateFunction(call)], input)
        .err()
        .unwrap();
    let message = failure.strip_backtrace();
    assert!(
        message.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed")
    );
    Ok(())
}
1435
/// `avg` over arguments that are already `Float64` (a literal, then an
/// explicit cast of an `Int32` column) is expected to come back unchanged.
#[test]
fn agg_function_case() -> Result<()> {
    // Float64 literal argument.
    let literal_avg = expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![lit(12f64)],
        false,
        None,
        None,
        None,
    );
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![Expr::AggregateFunction(literal_avg)],
        empty(),
    )?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: avg(Float64(12))\n EmptyRelation",
    )?;

    // Int32 column explicitly cast to Float64 by the caller.
    let cast_avg = expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![cast(col("a"), DataType::Float64)],
        false,
        None,
        None,
        None,
    );
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![Expr::AggregateFunction(cast_avg)],
        empty_with_type(DataType::Int32),
    )?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: avg(CAST(a AS Float64))\n EmptyRelation",
    )
}
1465
/// `avg` over a string literal is expected to be rejected when the
/// projection is built, with an error naming the accepted numeric types.
#[test]
fn agg_function_invalid_input_avg() -> Result<()> {
    let input = empty();
    let call = expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![lit("1")],
        false,
        None,
        None,
        None,
    );

    let failure = Projection::try_new(vec![Expr::AggregateFunction(call)], input)
        .err()
        .unwrap();
    let message = failure.strip_backtrace();
    assert!(message.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
    Ok(())
}
1484
/// `Date32 + IntervalDayTime` is expected to pass through analysis without
/// any additional cast.
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let interval = lit(ScalarValue::new_interval_dt(123, 456));
    let sum = date + interval;

    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], empty())?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation",
    )
}
1497
/// `IN` lists with mixed integer literals: against an `Int64` column only the
/// narrower literals are cast; against a `Decimal128(12, 4)` column both the
/// column and every literal are expected to be cast to `Decimal128(24, 4)`.
#[test]
fn inlist_case() -> Result<()> {
    let mixed_ints = || vec![lit(1_i32), lit(4_i8), lit(8_i64)];

    // Int64 column.
    let membership = col("a").in_list(mixed_ints(), false);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![membership],
        empty_with_type(DataType::Int64),
    )?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation",
    )?;

    // Decimal128(12, 4) column, with the input relation spelled out by hand.
    let membership = col("a").in_list(mixed_ints(), false);
    let decimal_schema = DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
        std::collections::HashMap::new(),
    )?;
    let decimal_input = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(decimal_schema),
    }));
    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![membership], decimal_input)?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation",
    )
}
1520
/// `BETWEEN` over a `Utf8` column with a string lower bound and a
/// `Date32 + interval` upper bound: the column and the string bound are
/// expected to be cast to `Date32`.
#[test]
fn between_case() -> Result<()> {
    let lower = lit("2002-05-08");
    let upper = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let predicate = col("a").between(lower, upper);

    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty_with_type(Utf8))?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) AND CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\")\n EmptyRelation",
    )
}
1536
/// Mirror of the previous `BETWEEN` case with the bounds swapped: the
/// `Date32 + interval` expression is the lower bound and the string literal
/// the upper one; the column and the string are expected to be cast to
/// `Date32`.
#[test]
fn between_infer_cheap_type() -> Result<()> {
    let lower = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let upper = lit("2002-12-08");
    let predicate = col("a").between(lower, upper);

    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty_with_type(Utf8))?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\n EmptyRelation",
    )
}
1553
/// `NULL BETWEEN NULL AND 2_i64`: both untyped nulls are expected to be cast
/// to `Int64`.
#[test]
fn between_null() -> Result<()> {
    let subject = lit(ScalarValue::Null);
    let predicate = subject.between(lit(ScalarValue::Null), lit(2i64));

    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty())?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)\n EmptyRelation",
    )
}
1564
/// `IS [NOT] TRUE` / `IS [NOT] FALSE` over a `Boolean` column are expected to
/// be left untouched, while `IS TRUE` over an `Int64` column is expected to
/// fail analysis.
#[test]
fn is_bool_for_type_coercion() -> Result<()> {
    // Projects `predicate` over a single-column input of `column_type`.
    let project = |predicate: Expr, column_type: DataType| -> Result<LogicalPlan> {
        let input = empty_with_type(column_type);
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![predicate],
            input,
        )?))
    };

    let is_true_expr = col("a").is_true();
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        project(is_true_expr.clone(), DataType::Boolean)?,
        "Projection: a IS TRUE\n EmptyRelation",
    )?;

    // Non-boolean input: analysis must report an error.
    let outcome = assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        project(is_true_expr, DataType::Int64)?,
        "",
    );
    let err = outcome.unwrap_err().to_string();
    assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}");

    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        project(col("a").is_not_true(), DataType::Boolean)?,
        "Projection: a IS NOT TRUE\n EmptyRelation",
    )?;

    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        project(col("a").is_false(), DataType::Boolean)?,
        "Projection: a IS FALSE\n EmptyRelation",
    )?;

    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        project(col("a").is_not_false(), DataType::Boolean)?,
        "Projection: a IS NOT FALSE\n EmptyRelation",
    )?;

    Ok(())
}
1604
/// Exercises coercion of `LIKE` and `ILIKE` expressions.
///
/// For each operator three cases are checked:
/// * a `Utf8` column against a `Utf8` pattern is expected to stay unchanged;
/// * a `Utf8` column against a `NULL` pattern is expected to get the pattern
///   cast to `Utf8`;
/// * an `Int64` column against a `Utf8` pattern is expected to fail analysis
///   with a "no common type" error.
///
/// The failing cases print the actual error text when the assertion does not
/// hold, so a changed message can be diagnosed from the test output.
#[test]
fn like_for_type_coercion() -> Result<()> {
    // Builds `Projection: a [I]LIKE <pattern>` over a single-column input
    // whose column has type `column_type`.
    let like_plan = |pattern: Expr,
                     case_insensitive: bool,
                     column_type: DataType|
     -> Result<LogicalPlan> {
        let like_expr = Expr::Like(Like::new(
            false,
            Box::new(col("a")),
            Box::new(pattern),
            None,
            case_insensitive,
        ));
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![like_expr],
            empty_with_type(column_type),
        )?))
    };

    // LIKE: matching types.
    let plan = like_plan(lit(ScalarValue::new_utf8("abc")), false, Utf8)?;
    let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;

    // LIKE: untyped NULL pattern.
    let plan = like_plan(lit(ScalarValue::Null), false, Utf8)?;
    let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;

    // LIKE: no common type. The expected plan text is irrelevant here
    // because analysis must fail before any comparison happens.
    let plan = like_plan(lit(ScalarValue::new_utf8("abc")), false, DataType::Int64)?;
    let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, "")
        .unwrap_err()
        .to_string();
    assert!(
        err.contains(
            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression"
        ),
        "{err}"
    );

    // ILIKE: matching types.
    let plan = like_plan(lit(ScalarValue::new_utf8("abc")), true, Utf8)?;
    let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;

    // ILIKE: untyped NULL pattern.
    let plan = like_plan(lit(ScalarValue::Null), true, Utf8)?;
    let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;

    // ILIKE: no common type.
    let plan = like_plan(lit(ScalarValue::new_utf8("abc")), true, DataType::Int64)?;
    let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, "")
        .unwrap_err()
        .to_string();
    assert!(
        err.contains(
            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression"
        ),
        "{err}"
    );

    Ok(())
}
1664
/// `IS UNKNOWN` / `IS NOT UNKNOWN` over a `Boolean` column are expected to be
/// left untouched, while `IS UNKNOWN` over a `Utf8` column is expected to fail
/// analysis.
#[test]
fn unknown_for_type_coercion() -> Result<()> {
    let is_unknown_expr = col("a").is_unknown();

    // Boolean column: nothing to coerce.
    let boolean_plan = LogicalPlan::Projection(Projection::try_new(
        vec![is_unknown_expr.clone()],
        empty_with_type(DataType::Boolean),
    )?);
    let expected = "Projection: a IS UNKNOWN\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), boolean_plan, expected)?;

    // Utf8 column: analysis must report an error.
    let string_plan = LogicalPlan::Projection(Projection::try_new(
        vec![is_unknown_expr],
        empty_with_type(Utf8),
    )?);
    let outcome =
        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), string_plan, expected);
    let err = outcome.unwrap_err().to_string();
    assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}");

    // Negated form over a Boolean column.
    let negated_plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a").is_not_unknown()],
        empty_with_type(DataType::Boolean),
    )?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        negated_plan,
        "Projection: a IS NOT UNKNOWN\n EmptyRelation",
    )?;

    Ok(())
}
1690
/// A variadic `Utf8` UDF called with a string column, a string literal, two
/// booleans and an integer: the non-string arguments are expected to be cast
/// to `Utf8`.
#[test]
fn concat_for_type_coercion() -> Result<()> {
    let input = empty_with_type(Utf8);
    let arguments = vec![col("a"), lit("b"), lit(true), lit(false), lit(13)];

    let variadic_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
    });
    let call = variadic_udf.call(arguments);

    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], input)?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation",
    )?;

    Ok(())
}
1713
/// Drives `TypeCoercionRewriter` directly: for `>`, `=` and `<` between an
/// `Int32` and an `Int64` literal nested under `IS TRUE`, the `Int32` side is
/// expected to be cast to `Int64`.
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Int64, true)].into(),
        std::collections::HashMap::new(),
    )?);

    // (expression to rewrite, expression expected after the rewrite)
    let cases = [
        (
            is_true(lit(12i32).gt(lit(13i64))),
            is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))),
        ),
        (
            is_true(lit(12i32).eq(lit(13i64))),
            is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))),
        ),
        (
            is_true(lit(12i32).lt(lit(13i64))),
            is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))),
        ),
    ];

    for (original, expected) in cases {
        // A fresh rewriter per case, all bound to the same schema.
        let mut rewriter = TypeCoercionRewriter { schema: &schema };
        let result = original.rewrite(&mut rewriter).data()?;
        assert_eq!(expected, result);
    }

    Ok(())
}
1751
/// Comparing a nanosecond timestamp with a `Date32` value: the date side is
/// expected to be cast to the timestamp type.
#[test]
fn binary_op_date32_eq_ts() -> Result<()> {
    let timestamp = cast(
        lit("1998-03-18"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let equality = timestamp.eq(date);

    let plan = LogicalPlan::Projection(Projection::try_new(vec![equality], empty())?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation",
    )
}
1766
1767 fn cast_if_not_same_type(
1768 expr: Box<Expr>,
1769 data_type: &DataType,
1770 schema: &DFSchemaRef,
1771 ) -> Box<Expr> {
1772 if &expr.get_type(schema).unwrap() != data_type {
1773 Box::new(cast(*expr, data_type.clone()))
1774 } else {
1775 expr
1776 }
1777 }
1778
1779 fn cast_helper(
1780 case: Case,
1781 case_when_type: &DataType,
1782 then_else_type: &DataType,
1783 schema: &DFSchemaRef,
1784 ) -> Case {
1785 let expr = case
1786 .expr
1787 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
1788 let when_then_expr = case
1789 .when_then_expr
1790 .into_iter()
1791 .map(|(when, then)| {
1792 (
1793 cast_if_not_same_type(when, case_when_type, schema),
1794 cast_if_not_same_type(then, then_else_type, schema),
1795 )
1796 })
1797 .collect::<Vec<_>>();
1798 let else_expr = case
1799 .else_expr
1800 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
1801
1802 Case {
1803 expr,
1804 when_then_expr,
1805 else_expr,
1806 }
1807 }
1808
/// Covers `coerce_case_expression` on four `CASE` shapes: two that coerce
/// successfully (checked against `cast_helper`) and two that must fail, one on
/// the case/when side and one on the then/else side.
#[test]
fn test_case_expression_coercion() -> Result<()> {
    let fields = vec![
        Field::new("boolean", DataType::Boolean, true),
        Field::new("integer", DataType::Int32, true),
        Field::new("float", DataType::Float32, true),
        Field::new(
            "timestamp",
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            true,
        ),
        Field::new("date", DataType::Date32, true),
        Field::new(
            "interval",
            DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
            true,
        ),
        Field::new("binary", DataType::Binary, true),
        Field::new("string", Utf8, true),
        Field::new("decimal", DataType::Decimal128(10, 10), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        fields.into(),
        std::collections::HashMap::new(),
    )?);

    // Shorthand for a boxed column reference.
    let column = |name: &str| Box::new(col(name));

    // No base expression: WHENs go to Boolean, THENs to Utf8.
    let searched = Case {
        expr: None,
        when_then_expr: vec![
            (column("boolean"), column("integer")),
            (column("integer"), column("float")),
            (column("string"), column("string")),
        ],
        else_expr: None,
    };
    let expected = cast_helper(searched.clone(), &DataType::Boolean, &Utf8, &schema);
    let actual = coerce_case_expression(searched, &schema)?;
    assert_eq!(expected, actual);

    // String base expression: everything goes to Utf8.
    let simple = Case {
        expr: Some(column("string")),
        when_then_expr: vec![
            (column("float"), column("integer")),
            (column("integer"), column("float")),
            (column("string"), column("string")),
        ],
        else_expr: Some(column("string")),
    };
    let expected = cast_helper(simple.clone(), &Utf8, &Utf8, &schema);
    let actual = coerce_case_expression(simple, &schema)?;
    assert_eq!(expected, actual);

    // Interval base expression against float/binary/string WHENs: must fail.
    let bad_when = Case {
        expr: Some(column("interval")),
        when_then_expr: vec![
            (column("float"), column("integer")),
            (column("binary"), column("float")),
            (column("string"), column("string")),
        ],
        else_expr: Some(column("string")),
    };
    let failure = coerce_case_expression(bad_when, &schema).unwrap_err();
    assert_eq!(
        failure.strip_backtrace(),
        "Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression"
    );

    // Date/float/binary THENs with a timestamp ELSE: must fail.
    let bad_then = Case {
        expr: Some(column("string")),
        when_then_expr: vec![
            (column("float"), column("date")),
            (column("string"), column("float")),
            (column("string"), column("binary")),
        ],
        else_expr: Some(column("timestamp")),
    };
    let failure = coerce_case_expression(bad_then, &schema).unwrap_err();
    assert_eq!(
        failure.strip_backtrace(),
        "Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression"
    );

    Ok(())
}
1913
// Builds a `Case` without an `ELSE` from an optional base column name and a
// list of (when, then) pairs, then asserts that `coerce_case_expression`
// yields what `cast_helper` produces for the given target types.
// Expands to statements that use `?`, so it must be invoked inside a
// function returning `Result`.
macro_rules! test_case_expression {
    ($base:expr, $branches:expr, $when_ty:expr, $then_ty:expr, $schema:expr) => {
        let input_case = Case {
            expr: $base.map(|name| Box::new(col(name))),
            when_then_expr: $branches,
            else_expr: None,
        };

        let expected = cast_helper(input_case.clone(), &$when_ty, &$then_ty, &$schema);

        let actual = coerce_case_expression(input_case, &$schema)?;
        assert_eq!(expected, actual);
    };
}
1929
/// Base expression and `WHEN` of different list flavours (`List`,
/// `LargeList`, `FixedSizeList` of `Int64`): checks which list type both
/// sides are coerced to, for every ordered pair.
#[test]
fn tes_case_when_list() -> Result<()> {
    let int64_item = || Arc::new(Field::new_list_field(DataType::Int64, true));
    let list_type = || DataType::List(int64_item());
    let large_list_type = || DataType::LargeList(int64_item());

    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("large_list", DataType::LargeList(int64_item()), true),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(int64_item(), 3),
                true,
            ),
            Field::new("list", DataType::List(int64_item()), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // (base column, WHEN column, expected common case/when type)
    let cases = [
        ("list", "large_list", large_list_type()),
        ("large_list", "list", large_list_type()),
        ("list", "fixed_list", list_type()),
        ("fixed_list", "list", list_type()),
        ("fixed_list", "large_list", large_list_type()),
        ("large_list", "fixed_list", large_list_type()),
    ];

    for (base, when, common_type) in cases {
        test_case_expression!(
            Some(base),
            vec![(Box::new(col(when)), Box::new(lit("1")))],
            common_type,
            Utf8,
            schema
        );
    }
    Ok(())
}
2000
/// Two `THEN` branches of different list flavours (`List`, `LargeList`,
/// `FixedSizeList` of `Int64`) under boolean `WHEN`s: checks which list type
/// the branches are coerced to, for every ordered pair.
#[test]
fn test_then_else_list() -> Result<()> {
    let int64_item = || Arc::new(Field::new_list_field(DataType::Int64, true));
    let list_type = || DataType::List(int64_item());
    let large_list_type = || DataType::LargeList(int64_item());

    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new("large_list", DataType::LargeList(int64_item()), true),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(int64_item(), 3),
                true,
            ),
            Field::new("list", DataType::List(int64_item()), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // (first THEN column, second THEN column, expected common THEN type)
    let cases = [
        ("large_list", "list", large_list_type()),
        ("list", "large_list", large_list_type()),
        ("fixed_list", "list", list_type()),
        ("list", "fixed_list", list_type()),
        ("fixed_list", "large_list", large_list_type()),
        ("large_list", "fixed_list", large_list_type()),
    ];

    for (first_then, second_then, common_type) in cases {
        test_case_expression!(
            None::<String>,
            vec![
                (Box::new(col("boolean")), Box::new(col(first_then))),
                (Box::new(col("boolean")), Box::new(col(second_then)))
            ],
            DataType::Boolean,
            common_type,
            schema
        );
    }
    Ok(())
}
2093
/// Compares a map column whose entries field is named `entries` with the same
/// column cast to a map whose entries field is named `key_value`; the cast
/// side is expected to be cast back to the `entries` map type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    let mut entry_builder = SchemaBuilder::new();
    entry_builder.push(Field::new("key", Utf8, false));
    entry_builder.push(Field::new("value", DataType::Float64, true));
    let entry_fields = entry_builder.finish().fields;

    // Two map types that differ only in the name of the entries field.
    let entries_field =
        Field::new("entries", DataType::Struct(entry_fields.clone()), false);
    let map_type_entries = DataType::Map(Arc::new(entries_field), false);

    let key_value_field = Field::new("key_value", DataType::Struct(entry_fields), false);
    let map_type_custom = DataType::Map(Arc::new(key_value_field), false);

    let comparison = col("a").eq(cast(col("a"), map_type_custom));
    let input = empty_with_type(map_type_entries);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![comparison], input)?);

    let expected = "Projection: a = CAST(CAST(a AS Map(Field { name: \"key_value\", data_type: Struct([Field { name: \"key\", data_type: Utf8, \
        nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), \
        nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: \"entries\", data_type: Struct([Field { name: \"key\", data_type: Utf8, nullable: false, \
        dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))\n \
        EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
}
2118
/// `IntervalYearMonth + Timestamp` (interval on the left) is expected to pass
/// through analysis without any additional cast.
#[test]
fn interval_plus_timestamp() -> Result<()> {
    let interval = lit(ScalarValue::IntervalYearMonth(Some(12)));
    let timestamp = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(interval),
        Operator::Plus,
        Box::new(timestamp),
    ));

    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], empty())?);
    assert_analyzed_plan_eq(
        Arc::new(TypeCoercion::new()),
        plan,
        "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation",
    )
}
2136
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both operands are the same date string cast to a nanosecond timestamp;
    // the closure builds a fresh expression for each side.
    let nanos_timestamp = || {
        cast(
            lit("1998-03-18"),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
    };
    let difference = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(nanos_timestamp()),
        Operator::Minus,
        Box::new(nanos_timestamp()),
    ));

    let projection = Projection::try_new(vec![difference], empty())?;
    let plan = LogicalPlan::Projection(projection);

    // The expected plan prints the subtraction as written, with no extra cast
    // on either operand.
    let expected =
        "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
}
2157
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    // The subquery input is built with Int32 and the outer filter input with
    // Int64.
    let inner_input = empty_with_type(DataType::Int32);
    let outer_input = empty_with_type(DataType::Int64);

    let subquery = Subquery {
        subquery: inner_input,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let filter = Filter::try_new(predicate, outer_input)?;
    let plan = LogicalPlan::Filter(filter);

    // The expected plan casts on the subquery side only: a projection casting
    // `a` to Int64 is placed inside the subquery.
    let expected = "Filter: a IN (<subquery>)\
    \n Subquery:\
    \n Projection: CAST(a AS Int64)\
    \n EmptyRelation\
    \n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
}
2183
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    // Here the outer filter input is built with Int32 and the subquery input
    // with Int64.
    let outer_input = empty_with_type(DataType::Int32);
    let inner_input = empty_with_type(DataType::Int64);

    let subquery = Subquery {
        subquery: inner_input,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let filter = Filter::try_new(predicate, outer_input)?;
    let plan = LogicalPlan::Filter(filter);

    // The expected plan casts the outer expression to Int64 and leaves the
    // subquery plan untouched.
    let expected = "Filter: CAST(a AS Int64) IN (<subquery>)\
    \n Subquery:\
    \n EmptyRelation\
    \n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
}
2208
#[test]
fn in_subquery_cast_all() -> Result<()> {
    // The two inputs are built with different decimal types: Decimal128(10, 5)
    // for the subquery and Decimal128(8, 8) for the outer filter.
    let inner_input = empty_with_type(DataType::Decimal128(10, 5));
    let outer_input = empty_with_type(DataType::Decimal128(8, 8));

    let subquery = Subquery {
        subquery: inner_input,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let filter = Filter::try_new(predicate, outer_input)?;
    let plan = LogicalPlan::Filter(filter);

    // The expected plan casts both sides to Decimal128(13, 8): the outer
    // expression directly, the subquery through an added projection.
    let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)\
    \n Subquery:\
    \n Projection: CAST(a AS Decimal128(13, 8))\
    \n EmptyRelation\
    \n EmptyRelation";
    assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
}
2233}