1use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::{izip, Itertools as _};
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
33 plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
34 TableReference,
35};
36use datafusion_expr::expr::{
37 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
38 InSubquery, Like, ScalarFunction, Sort, WindowFunction,
39};
40use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
41use datafusion_expr::expr_schema::cast_subquery;
42use datafusion_expr::logical_plan::Subquery;
43use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
44use datafusion_expr::type_coercion::functions::{
45 data_types_with_scalar_udf, fields_with_aggregate_udf,
46};
47use datafusion_expr::type_coercion::other::{
48 get_coerce_type_for_case_expression, get_coerce_type_for_list,
49};
50use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
51use datafusion_expr::utils::merge_schema;
52use datafusion_expr::{
53 is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
54 AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection,
55 ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
56};
57
/// Analyzer rule that rewrites a logical plan so that the types of its
/// expressions match what their operators, functions and plan nodes require,
/// inserting casts where necessary.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates the type-coercion analyzer rule.
    pub fn new() -> Self {
        Self::default()
    }
}
68
69fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
71 if !config.optimizer.expand_views_at_output {
72 return Ok(plan);
73 }
74
75 let outer_refs = plan.expressions();
76 if outer_refs.is_empty() {
77 return Ok(plan);
78 }
79
80 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
81 coerce_plan_expr_for_schema(plan, &dfschema?)
82 } else {
83 Ok(plan)
84 }
85}
86
87impl AnalyzerRule for TypeCoercion {
88 fn name(&self) -> &str {
89 "type_coercion"
90 }
91
92 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
93 let empty_schema = DFSchema::empty();
94
95 let transformed_plan = plan
97 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
98 .data;
99
100 coerce_output(transformed_plan, config)
102 }
103}
104
105fn analyze_internal(
109 external_schema: &DFSchema,
110 plan: LogicalPlan,
111) -> Result<Transformed<LogicalPlan>> {
112 let mut schema = merge_schema(&plan.inputs());
115
116 if let LogicalPlan::TableScan(ts) = &plan {
117 let source_schema = DFSchema::try_from_qualified_schema(
118 ts.table_name.clone(),
119 &ts.source.schema(),
120 )?;
121 schema.merge(&source_schema);
122 }
123
124 schema.merge(external_schema);
128
129 let plan = if let LogicalPlan::Filter(mut filter) = plan {
131 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
132 LogicalPlan::Filter(filter)
133 } else {
134 plan
135 };
136
137 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
138
139 let name_preserver = NamePreserver::new(&plan);
140 plan.map_expressions(|expr| {
142 let original_name = name_preserver.save(&expr);
143 expr.rewrite(&mut expr_rewrite)
144 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
145 })?
146 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
148 .map_data(|plan| plan.recompute_schema())
150}
151
152pub struct TypeCoercionRewriter<'a> {
154 pub(crate) schema: &'a DFSchema,
155}
156
157impl<'a> TypeCoercionRewriter<'a> {
158 pub fn new(schema: &'a DFSchema) -> Self {
161 Self { schema }
162 }
163
164 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
169 match plan {
170 LogicalPlan::Join(join) => self.coerce_join(join),
171 LogicalPlan::Union(union) => Self::coerce_union(union),
172 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
173 _ => Ok(plan),
174 }
175 }
176
177 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
186 join.on = join
187 .on
188 .into_iter()
189 .map(|(lhs, rhs)| {
190 let left_schema = join.left.schema();
193 let right_schema = join.right.schema();
194 let (lhs, rhs) = self.coerce_binary_op(
195 lhs,
196 left_schema,
197 Operator::Eq,
198 rhs,
199 right_schema,
200 )?;
201 Ok((lhs, rhs))
202 })
203 .collect::<Result<Vec<_>>>()?;
204
205 join.filter = join
207 .filter
208 .map(|expr| self.coerce_join_filter(expr))
209 .transpose()?;
210
211 Ok(LogicalPlan::Join(join))
212 }
213
214 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
217 let union_schema = Arc::new(coerce_union_schema_with_schema(
218 &union_plan.inputs,
219 &union_plan.schema,
220 )?);
221 let new_inputs = union_plan
222 .inputs
223 .into_iter()
224 .map(|p| {
225 let plan =
226 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
227 match plan {
228 LogicalPlan::Projection(Projection { expr, input, .. }) => {
229 Ok(Arc::new(project_with_column_index(
230 expr,
231 input,
232 Arc::clone(&union_schema),
233 )?))
234 }
235 other_plan => Ok(Arc::new(other_plan)),
236 }
237 })
238 .collect::<Result<Vec<_>>>()?;
239 Ok(LogicalPlan::Union(Union {
240 inputs: new_inputs,
241 schema: union_schema,
242 }))
243 }
244
245 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
247 fn coerce_limit_expr(
248 expr: Expr,
249 schema: &DFSchema,
250 expr_name: &str,
251 ) -> Result<Expr> {
252 let dt = expr.get_type(schema)?;
253 if dt.is_integer() || dt.is_null() {
254 expr.cast_to(&DataType::Int64, schema)
255 } else {
256 plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
257 }
258 }
259
260 let empty_schema = DFSchema::empty();
261 let new_fetch = limit
262 .fetch
263 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
264 .transpose()?;
265 let new_skip = limit
266 .skip
267 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
268 .transpose()?;
269 Ok(LogicalPlan::Limit(Limit {
270 input: limit.input,
271 fetch: new_fetch.map(Box::new),
272 skip: new_skip.map(Box::new),
273 }))
274 }
275
276 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
277 let expr_type = expr.get_type(self.schema)?;
278 match expr_type {
279 DataType::Boolean => Ok(expr),
280 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
281 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
282 }
283 }
284
285 fn coerce_binary_op(
286 &self,
287 left: Expr,
288 left_schema: &DFSchema,
289 op: Operator,
290 right: Expr,
291 right_schema: &DFSchema,
292 ) -> Result<(Expr, Expr)> {
293 let (left_type, right_type) = BinaryTypeCoercer::new(
294 &left.get_type(left_schema)?,
295 &op,
296 &right.get_type(right_schema)?,
297 )
298 .get_input_types()?;
299
300 Ok((
301 left.cast_to(&left_type, left_schema)?,
302 right.cast_to(&right_type, right_schema)?,
303 ))
304 }
305}
306
307impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
308 type Node = Expr;
309
310 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
311 match expr {
312 Expr::Unnest(_) => not_impl_err!(
313 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
314 ),
315 Expr::ScalarSubquery(Subquery {
316 subquery,
317 outer_ref_columns,
318 spans,
319 }) => {
320 let new_plan =
321 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
322 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
323 subquery: Arc::new(new_plan),
324 outer_ref_columns,
325 spans,
326 })))
327 }
328 Expr::Exists(Exists { subquery, negated }) => {
329 let new_plan = analyze_internal(
330 self.schema,
331 Arc::unwrap_or_clone(subquery.subquery),
332 )?
333 .data;
334 Ok(Transformed::yes(Expr::Exists(Exists {
335 subquery: Subquery {
336 subquery: Arc::new(new_plan),
337 outer_ref_columns: subquery.outer_ref_columns,
338 spans: subquery.spans,
339 },
340 negated,
341 })))
342 }
343 Expr::InSubquery(InSubquery {
344 expr,
345 subquery,
346 negated,
347 }) => {
348 let new_plan = analyze_internal(
349 self.schema,
350 Arc::unwrap_or_clone(subquery.subquery),
351 )?
352 .data;
353 let expr_type = expr.get_type(self.schema)?;
354 let subquery_type = new_plan.schema().field(0).data_type();
355 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
356 plan_datafusion_err!(
357 "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
358 ),
359 )?;
360 let new_subquery = Subquery {
361 subquery: Arc::new(new_plan),
362 outer_ref_columns: subquery.outer_ref_columns,
363 spans: subquery.spans,
364 };
365 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
366 Box::new(expr.cast_to(&common_type, self.schema)?),
367 cast_subquery(new_subquery, &common_type)?,
368 negated,
369 ))))
370 }
371 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
372 *expr,
373 self.schema,
374 )?))),
375 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
376 get_casted_expr_for_bool_op(*expr, self.schema)?,
377 ))),
378 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
379 get_casted_expr_for_bool_op(*expr, self.schema)?,
380 ))),
381 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
382 get_casted_expr_for_bool_op(*expr, self.schema)?,
383 ))),
384 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
385 get_casted_expr_for_bool_op(*expr, self.schema)?,
386 ))),
387 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
388 get_casted_expr_for_bool_op(*expr, self.schema)?,
389 ))),
390 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
391 get_casted_expr_for_bool_op(*expr, self.schema)?,
392 ))),
393 Expr::Like(Like {
394 negated,
395 expr,
396 pattern,
397 escape_char,
398 case_insensitive,
399 }) => {
400 let left_type = expr.get_type(self.schema)?;
401 let right_type = pattern.get_type(self.schema)?;
402 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
403 let op_name = if case_insensitive {
404 "ILIKE"
405 } else {
406 "LIKE"
407 };
408 plan_datafusion_err!(
409 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
410 )
411 })?;
412 let expr = match left_type {
413 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
414 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
415 };
416 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
417 Ok(Transformed::yes(Expr::Like(Like::new(
418 negated,
419 expr,
420 pattern,
421 escape_char,
422 case_insensitive,
423 ))))
424 }
425 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
426 let (left, right) =
427 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
428 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
429 Box::new(left),
430 op,
431 Box::new(right),
432 ))))
433 }
434 Expr::Between(Between {
435 expr,
436 negated,
437 low,
438 high,
439 }) => {
440 let expr_type = expr.get_type(self.schema)?;
441 let low_type = low.get_type(self.schema)?;
442 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
443 .ok_or_else(|| {
444 internal_datafusion_err!(
445 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
446 )
447 })?;
448 let high_type = high.get_type(self.schema)?;
449 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
450 .ok_or_else(|| {
451 internal_datafusion_err!(
452 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
453 )
454 })?;
455 let coercion_type =
456 comparison_coercion(&low_coerced_type, &high_coerced_type)
457 .ok_or_else(|| {
458 internal_datafusion_err!(
459 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
460 )
461 })?;
462 Ok(Transformed::yes(Expr::Between(Between::new(
463 Box::new(expr.cast_to(&coercion_type, self.schema)?),
464 negated,
465 Box::new(low.cast_to(&coercion_type, self.schema)?),
466 Box::new(high.cast_to(&coercion_type, self.schema)?),
467 ))))
468 }
469 Expr::InList(InList {
470 expr,
471 list,
472 negated,
473 }) => {
474 let expr_data_type = expr.get_type(self.schema)?;
475 let list_data_types = list
476 .iter()
477 .map(|list_expr| list_expr.get_type(self.schema))
478 .collect::<Result<Vec<_>>>()?;
479 let result_type =
480 get_coerce_type_for_list(&expr_data_type, &list_data_types);
481 match result_type {
482 None => plan_err!(
483 "Can not find compatible types to compare {expr_data_type} with [{}]", list_data_types.iter().join(", ")
484 ),
485 Some(coerced_type) => {
486 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
488 let cast_list_expr = list
489 .into_iter()
490 .map(|list_expr| {
491 list_expr.cast_to(&coerced_type, self.schema)
492 })
493 .collect::<Result<Vec<_>>>()?;
494 Ok(Transformed::yes(Expr::InList(InList ::new(
495 Box::new(cast_expr),
496 cast_list_expr,
497 negated,
498 ))))
499 }
500 }
501 }
502 Expr::Case(case) => {
503 let case = coerce_case_expression(case, self.schema)?;
504 Ok(Transformed::yes(Expr::Case(case)))
505 }
506 Expr::ScalarFunction(ScalarFunction { func, args }) => {
507 let new_expr = coerce_arguments_for_signature_with_scalar_udf(
508 args,
509 self.schema,
510 &func,
511 )?;
512 Ok(Transformed::yes(Expr::ScalarFunction(
513 ScalarFunction::new_udf(func, new_expr),
514 )))
515 }
516 Expr::AggregateFunction(expr::AggregateFunction {
517 func,
518 params:
519 AggregateFunctionParams {
520 args,
521 distinct,
522 filter,
523 order_by,
524 null_treatment,
525 },
526 }) => {
527 let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
528 args,
529 self.schema,
530 &func,
531 )?;
532 Ok(Transformed::yes(Expr::AggregateFunction(
533 expr::AggregateFunction::new_udf(
534 func,
535 new_expr,
536 distinct,
537 filter,
538 order_by,
539 null_treatment,
540 ),
541 )))
542 }
543 Expr::WindowFunction(window_fun) => {
544 let WindowFunction {
545 fun,
546 params:
547 expr::WindowFunctionParams {
548 args,
549 partition_by,
550 order_by,
551 window_frame,
552 filter,
553 null_treatment,
554 distinct,
555 },
556 } = *window_fun;
557 let window_frame =
558 coerce_window_frame(window_frame, self.schema, &order_by)?;
559
560 let args = match &fun {
561 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
562 coerce_arguments_for_signature_with_aggregate_udf(
563 args,
564 self.schema,
565 udf,
566 )?
567 }
568 _ => args,
569 };
570
571 let new_expr = Expr::from(WindowFunction {
572 fun,
573 params: expr::WindowFunctionParams {
574 args,
575 partition_by,
576 order_by,
577 window_frame,
578 filter,
579 null_treatment,
580 distinct,
581 },
582 });
583 Ok(Transformed::yes(new_expr))
584 }
585 #[expect(deprecated)]
587 Expr::Alias(_)
588 | Expr::Column(_)
589 | Expr::ScalarVariable(_, _)
590 | Expr::Literal(_, _)
591 | Expr::SimilarTo(_)
592 | Expr::IsNotNull(_)
593 | Expr::IsNull(_)
594 | Expr::Negative(_)
595 | Expr::Cast(_)
596 | Expr::TryCast(_)
597 | Expr::Wildcard { .. }
598 | Expr::GroupingSet(_)
599 | Expr::Placeholder(_)
600 | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
601 }
602 }
603}
604
605fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
607 let metadata = dfschema.as_arrow().metadata.clone();
608 let mut transformed = false;
609
610 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
611 dfschema
612 .iter()
613 .map(|(qualifier, field)| match field.data_type() {
614 DataType::Utf8View => {
615 transformed = true;
616 (
617 qualifier.cloned() as Option<TableReference>,
618 Arc::new(Field::new(
619 field.name(),
620 DataType::LargeUtf8,
621 field.is_nullable(),
622 )),
623 )
624 }
625 DataType::BinaryView => {
626 transformed = true;
627 (
628 qualifier.cloned() as Option<TableReference>,
629 Arc::new(Field::new(
630 field.name(),
631 DataType::LargeBinary,
632 field.is_nullable(),
633 )),
634 )
635 }
636 _ => (
637 qualifier.cloned() as Option<TableReference>,
638 Arc::clone(field),
639 ),
640 })
641 .unzip();
642
643 if !transformed {
644 return None;
645 }
646
647 let schema = Schema::new_with_metadata(transformed_fields, metadata);
648 Some(DFSchema::from_field_specific_qualified_schema(
649 qualifiers,
650 &Arc::new(schema),
651 ))
652}
653
654fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
657 match value {
658 ScalarValue::Utf8(Some(val)) => {
660 ScalarValue::try_from_string(val.clone(), target_type)
661 }
662 s => {
663 if s.is_null() {
664 ScalarValue::try_from(target_type)
666 } else {
667 Ok(s.clone())
671 }
672 }
673 }
674}
675
676fn coerce_scalar_range_aware(
683 target_type: &DataType,
684 value: &ScalarValue,
685) -> Result<ScalarValue> {
686 coerce_scalar(target_type, value).or_else(|err| {
687 if let Some(largest_type) = get_widest_type_in_family(target_type) {
689 coerce_scalar(largest_type, value).map_or_else(
690 |_| exec_err!("Cannot cast {value:?} to {target_type}"),
691 |_| ScalarValue::try_from(target_type),
692 )
693 } else {
694 Err(err)
695 }
696 })
697}
698
699fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
703 match given_type {
704 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
705 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
706 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
707 _ => None,
708 }
709}
710
711fn coerce_frame_bound(
713 target_type: &DataType,
714 bound: WindowFrameBound,
715) -> Result<WindowFrameBound> {
716 match bound {
717 WindowFrameBound::Preceding(v) => {
718 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
719 }
720 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
721 WindowFrameBound::Following(v) => {
722 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
723 }
724 }
725}
726
727fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
728 if col_type.is_numeric()
729 || is_utf8_or_utf8view_or_large_utf8(col_type)
730 || matches!(col_type, DataType::List(_))
731 || matches!(col_type, DataType::LargeList(_))
732 || matches!(col_type, DataType::FixedSizeList(_, _))
733 || matches!(col_type, DataType::Null)
734 || matches!(col_type, DataType::Boolean)
735 {
736 Ok(col_type.clone())
737 } else if is_datetime(col_type) {
738 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
739 } else if let DataType::Dictionary(_, value_type) = col_type {
740 extract_window_frame_target_type(value_type)
741 } else {
742 internal_err!("Cannot run range queries on datatype: {col_type}")
743 }
744}
745
746fn coerce_window_frame(
749 window_frame: WindowFrame,
750 schema: &DFSchema,
751 expressions: &[Sort],
752) -> Result<WindowFrame> {
753 let mut window_frame = window_frame;
754 let target_type = match window_frame.units {
755 WindowFrameUnits::Range => {
756 let current_types = expressions
757 .first()
758 .map(|s| s.expr.get_type(schema))
759 .transpose()?;
760 if let Some(col_type) = current_types {
761 extract_window_frame_target_type(&col_type)?
762 } else {
763 return internal_err!("ORDER BY column cannot be empty");
764 }
765 }
766 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
767 };
768 window_frame.start_bound =
769 coerce_frame_bound(&target_type, window_frame.start_bound)?;
770 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
771 Ok(window_frame)
772}
773
774fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
777 let left_type = expr.get_type(schema)?;
778 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
779 .get_input_types()?;
780 expr.cast_to(&DataType::Boolean, schema)
781}
782
783fn coerce_arguments_for_signature_with_scalar_udf(
788 expressions: Vec<Expr>,
789 schema: &DFSchema,
790 func: &ScalarUDF,
791) -> Result<Vec<Expr>> {
792 if expressions.is_empty() {
793 return Ok(expressions);
794 }
795
796 let current_types = expressions
797 .iter()
798 .map(|e| e.get_type(schema))
799 .collect::<Result<Vec<_>>>()?;
800
801 let new_types = data_types_with_scalar_udf(¤t_types, func)?;
802
803 expressions
804 .into_iter()
805 .enumerate()
806 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
807 .collect()
808}
809
810fn coerce_arguments_for_signature_with_aggregate_udf(
815 expressions: Vec<Expr>,
816 schema: &DFSchema,
817 func: &AggregateUDF,
818) -> Result<Vec<Expr>> {
819 if expressions.is_empty() {
820 return Ok(expressions);
821 }
822
823 let current_fields = expressions
824 .iter()
825 .map(|e| e.to_field(schema).map(|(_, f)| f))
826 .collect::<Result<Vec<_>>>()?;
827
828 let new_types = fields_with_aggregate_udf(¤t_fields, func)?
829 .into_iter()
830 .map(|f| f.data_type().clone())
831 .collect::<Vec<_>>();
832
833 expressions
834 .into_iter()
835 .enumerate()
836 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
837 .collect()
838}
839
840fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
841 let case_type = case
873 .expr
874 .as_ref()
875 .map(|expr| expr.get_type(schema))
876 .transpose()?;
877 let then_types = case
878 .when_then_expr
879 .iter()
880 .map(|(_when, then)| then.get_type(schema))
881 .collect::<Result<Vec<_>>>()?;
882 let else_type = case
883 .else_expr
884 .as_ref()
885 .map(|expr| expr.get_type(schema))
886 .transpose()?;
887
888 let case_when_coerce_type = case_type
890 .as_ref()
891 .map(|case_type| {
892 let when_types = case
893 .when_then_expr
894 .iter()
895 .map(|(when, _then)| when.get_type(schema))
896 .collect::<Result<Vec<_>>>()?;
897 let coerced_type =
898 get_coerce_type_for_case_expression(&when_types, Some(case_type));
899 coerced_type.ok_or_else(|| {
900 plan_datafusion_err!(
901 "Failed to coerce case ({case_type}) and when ({}) \
902 to common types in CASE WHEN expression",
903 when_types.iter().join(", ")
904 )
905 })
906 })
907 .transpose()?;
908 let then_else_coerce_type =
909 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
910 || {
911 if let Some(else_type) = else_type {
912 plan_datafusion_err!(
913 "Failed to coerce then ({}) and else ({else_type}) \
914 to common types in CASE WHEN expression",
915 then_types.iter().join(", ")
916 )
917 } else {
918 plan_datafusion_err!(
919 "Failed to coerce then ({}) and else (None) \
920 to common types in CASE WHEN expression",
921 then_types.iter().join(", ")
922 )
923 }
924 },
925 )?;
926
927 let case_expr = case
929 .expr
930 .zip(case_when_coerce_type.as_ref())
931 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
932 .transpose()?
933 .map(Box::new);
934 let when_then = case
935 .when_then_expr
936 .into_iter()
937 .map(|(when, then)| {
938 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
939 let when = when.cast_to(when_type, schema).map_err(|e| {
940 DataFusionError::Context(
941 format!(
942 "WHEN expressions in CASE couldn't be \
943 converted to common type ({when_type})"
944 ),
945 Box::new(e),
946 )
947 })?;
948 let then = then.cast_to(&then_else_coerce_type, schema)?;
949 Ok((Box::new(when), Box::new(then)))
950 })
951 .collect::<Result<Vec<_>>>()?;
952 let else_expr = case
953 .else_expr
954 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
955 .transpose()?
956 .map(Box::new);
957
958 Ok(Case::new(case_expr, when_then, else_expr))
959}
960
961pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1003 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1004}
1005fn coerce_union_schema_with_schema(
1006 inputs: &[Arc<LogicalPlan>],
1007 base_schema: &DFSchemaRef,
1008) -> Result<DFSchema> {
1009 let mut union_datatypes = base_schema
1010 .fields()
1011 .iter()
1012 .map(|f| f.data_type().clone())
1013 .collect::<Vec<_>>();
1014 let mut union_nullabilities = base_schema
1015 .fields()
1016 .iter()
1017 .map(|f| f.is_nullable())
1018 .collect::<Vec<_>>();
1019 let mut union_field_meta = base_schema
1020 .fields()
1021 .iter()
1022 .map(|f| f.metadata().clone())
1023 .collect::<Vec<_>>();
1024
1025 let mut metadata = base_schema.metadata().clone();
1026
1027 for (i, plan) in inputs.iter().enumerate() {
1028 let plan_schema = plan.schema();
1029 metadata.extend(plan_schema.metadata().clone());
1030
1031 if plan_schema.fields().len() != base_schema.fields().len() {
1032 return plan_err!(
1033 "Union schemas have different number of fields: \
1034 query 1 has {} fields whereas query {} has {} fields",
1035 base_schema.fields().len(),
1036 i + 1,
1037 plan_schema.fields().len()
1038 );
1039 }
1040
1041 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1043 union_datatypes.iter_mut(),
1044 union_nullabilities.iter_mut(),
1045 union_field_meta.iter_mut(),
1046 plan_schema.fields().iter()
1047 ) {
1048 let coerced_type =
1049 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1050 || {
1051 plan_datafusion_err!(
1052 "Incompatible inputs for Union: Previous inputs were \
1053 of type {}, but got incompatible type {} on column '{}'",
1054 union_datatype,
1055 plan_field.data_type(),
1056 plan_field.name()
1057 )
1058 },
1059 )?;
1060
1061 *union_datatype = coerced_type;
1062 *union_nullable = *union_nullable || plan_field.is_nullable();
1063 union_field_map.extend(plan_field.metadata().clone());
1064 }
1065 }
1066 let union_qualified_fields = izip!(
1067 base_schema.fields(),
1068 union_datatypes.into_iter(),
1069 union_nullabilities,
1070 union_field_meta.into_iter()
1071 )
1072 .map(|(field, datatype, nullable, metadata)| {
1073 let mut field = Field::new(field.name().clone(), datatype, nullable);
1074 field.set_metadata(metadata);
1075 (None, field.into())
1076 })
1077 .collect::<Vec<_>>();
1078
1079 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1080}
1081
1082fn project_with_column_index(
1084 expr: Vec<Expr>,
1085 input: Arc<LogicalPlan>,
1086 schema: DFSchemaRef,
1087) -> Result<LogicalPlan> {
1088 let alias_expr = expr
1089 .into_iter()
1090 .enumerate()
1091 .map(|(i, e)| match e {
1092 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1093 Ok(e.unalias().alias(schema.field(i).name()))
1094 }
1095 Expr::Column(Column {
1096 relation: _,
1097 ref name,
1098 spans: _,
1099 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1100 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1101 #[expect(deprecated)]
1102 Expr::Wildcard { .. } => {
1103 plan_err!("Wildcard should be expanded before type coercion")
1104 }
1105 _ => Ok(e.alias(schema.field(i).name())),
1106 })
1107 .collect::<Result<Vec<_>>>()?;
1108
1109 Projection::try_new_with_schema(alias_expr, input, schema)
1110 .map(LogicalPlan::Projection)
1111}
1112
1113#[cfg(test)]
1114mod test {
1115 use std::any::Any;
1116 use std::sync::Arc;
1117
1118 use arrow::datatypes::DataType::Utf8;
1119 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1120 use insta::assert_snapshot;
1121
1122 use crate::analyzer::type_coercion::{
1123 coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1124 };
1125 use crate::analyzer::Analyzer;
1126 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1127 use datafusion_common::config::ConfigOptions;
1128 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1129 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1130 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1131 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1132 use datafusion_expr::test::function_stub::avg_udaf;
1133 use datafusion_expr::{
1134 cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1135 BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1136 Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1137 SimpleAggregateUDF, Subquery, Union, Volatility,
1138 };
1139 use datafusion_functions_aggregate::average::AvgAccumulator;
1140 use datafusion_sql::TableReference;
1141
1142 fn empty() -> Arc<LogicalPlan> {
1143 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1144 produce_one_row: false,
1145 schema: Arc::new(DFSchema::empty()),
1146 }))
1147 }
1148
1149 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1150 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1151 produce_one_row: false,
1152 schema: Arc::new(
1153 DFSchema::from_unqualified_fields(
1154 vec![Field::new("a", data_type, true)].into(),
1155 std::collections::HashMap::new(),
1156 )
1157 .unwrap(),
1158 ),
1159 }))
1160 }
1161
1162 macro_rules! assert_analyzed_plan_eq {
1163 (
1164 $plan: expr,
1165 @ $expected: literal $(,)?
1166 ) => {{
1167 let options = ConfigOptions::default();
1168 let rule = Arc::new(TypeCoercion::new());
1169 assert_analyzed_plan_with_config_eq_snapshot!(
1170 options,
1171 rule,
1172 $plan,
1173 @ $expected,
1174 )
1175 }};
1176 }
1177
1178 macro_rules! coerce_on_output_if_viewtype {
1179 (
1180 $is_viewtype: expr,
1181 $plan: expr,
1182 @ $expected: literal $(,)?
1183 ) => {{
1184 let mut options = ConfigOptions::default();
1185 if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1187 let rule = Arc::new(TypeCoercion::new());
1188
1189 assert_analyzed_plan_with_config_eq_snapshot!(
1190 options,
1191 rule,
1192 $plan,
1193 @ $expected,
1194 )
1195 }};
1196 }
1197
1198 fn assert_type_coercion_error(
1199 plan: LogicalPlan,
1200 expected_substr: &str,
1201 ) -> Result<()> {
1202 let options = ConfigOptions::default();
1203 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1204
1205 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1206 Ok(succeeded_plan) => {
1207 panic!(
1208 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1209 );
1210 }
1211 Err(e) => {
1212 let msg = e.to_string();
1213 assert!(
1214 msg.contains(expected_substr),
1215 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1216 );
1217 }
1218 }
1219
1220 Ok(())
1221 }
1222
1223 #[test]
1224 fn simple_case() -> Result<()> {
1225 let expr = col("a").lt(lit(2_u32));
1226 let empty = empty_with_type(DataType::Float64);
1227 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1228
1229 assert_analyzed_plan_eq!(
1230 plan,
1231 @r"
1232 Projection: a < CAST(UInt32(2) AS Float64)
1233 EmptyRelation: rows=0
1234 "
1235 )
1236 }
1237
/// Coercing a `UNION` of an Int32 input and an Int64 input wraps the Int32 side
/// in a projection that casts `a` to Int64. A projection built on top of the
/// analyzed union can still reference `a`.
#[test]
fn test_coerce_union() -> Result<()> {
    // Both inputs are empty relations qualified as `datafusion.test.foo` with a
    // single non-nullable column `a`; only the column type differs. Schema
    // construction errors are propagated with `?` instead of panicking.
    let relation_with = |data_type: DataType| -> Result<Arc<LogicalPlan>> {
        let schema = DFSchema::try_from_qualified_schema(
            TableReference::full("datafusion", "test", "foo"),
            &Schema::new(vec![Field::new("a", data_type, false)]),
        )?;
        Ok(Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(schema),
        })))
    };
    let left_plan = relation_with(DataType::Int32)?;
    let right_plan = relation_with(DataType::Int64)?;

    let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
        left_plan, right_plan,
    ])?);
    let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
        .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
    let top_level_plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(analyzed_union),
    )?);

    assert_analyzed_plan_eq!(
        top_level_plan,
        @r"
    Projection: a
      Union
        Projection: CAST(datafusion.test.foo.a AS Int64) AS a
          EmptyRelation: rows=0
        EmptyRelation: rows=0
    "
    )
}
1281
/// Checks output coercion for a Utf8View column.
///
/// When `expand_views_at_output` is disabled (first macro argument `false`),
/// the plan is left as analyzed. When it is enabled (`true`), a Utf8View output
/// column is cast to LargeUtf8. A comparison produces a Boolean output, so it
/// is expected to be unchanged in both modes, matching the BinaryView test.
#[test]
fn coerce_utf8view_output() -> Result<()> {
    let expr = col("a");
    let empty = empty_with_type(DataType::Utf8View);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![expr.clone()],
        Arc::clone(&empty),
    )?);

    // Projection that outputs the Utf8View column directly.
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
    Projection: a
      EmptyRelation: rows=0
    "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      EmptyRelation: rows=0
    "
    )?;

    // Projection whose output is Boolean. The literal is coerced to Utf8View
    // for the comparison, and no output cast is expected in either mode.
    let bool_expr = col("a").lt(lit("foo"));
    let bool_plan = LogicalPlan::Projection(Projection::try_new(
        vec![bool_expr],
        Arc::clone(&empty),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        bool_plan.clone(),
        @r#"
    Projection: a < CAST(Utf8("foo") AS Utf8View)
      EmptyRelation: rows=0
    "#
    )?;

    coerce_on_output_if_viewtype!(
        true,
        bool_plan.clone(),
        @r#"
    Projection: a < CAST(Utf8("foo") AS Utf8View)
      EmptyRelation: rows=0
    "#
    )?;

    // Sort over the projection. When enabled, the cast is placed in a new
    // projection above the sort.
    let sort_expr = expr.sort(true, true);
    let sort_plan = LogicalPlan::Sort(Sort {
        expr: vec![sort_expr],
        input: Arc::new(plan),
        fetch: None,
    });

    coerce_on_output_if_viewtype!(
        false,
        sort_plan.clone(),
        @r"
    Sort: a ASC NULLS FIRST
      Projection: a
        EmptyRelation: rows=0
    "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        sort_plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    // An explicit projection above the sort gets the cast itself; no extra
    // projection is stacked on top.
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sort_plan),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
    Projection: a
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    Ok(())
}
1412
/// Output coercion for a BinaryView column.
///
/// The plan is unchanged when `expand_views_at_output` is off (`false`). When
/// it is on (`true`), the column is cast to LargeBinary. A Boolean-valued
/// comparison is left as-is by the setting.
#[test]
fn coerce_binaryview_output() -> Result<()> {
    let column_a = col("a");
    let view_input = empty_with_type(DataType::BinaryView);
    let projection = LogicalPlan::Projection(Projection::try_new(
        vec![column_a.clone()],
        Arc::clone(&view_input),
    )?);

    // Direct projection of the view column.
    coerce_on_output_if_viewtype!(
        false,
        projection.clone(),
        @r"
    Projection: a
      EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        projection.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      EmptyRelation: rows=0
    "
    )?;

    // `a < <binary literal>`: the literal becomes BinaryView, and the output is Boolean.
    let comparison = LogicalPlan::Projection(Projection::try_new(
        vec![col("a").lt(lit(vec![8, 1, 8, 1]))],
        Arc::clone(&view_input),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        comparison.clone(),
        @r#"
    Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
      EmptyRelation: rows=0
    "#
    )?;
    coerce_on_output_if_viewtype!(
        true,
        comparison.clone(),
        @r#"
    Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
      EmptyRelation: rows=0
    "#
    )?;

    // A sort on top of the projection: the cast goes in a new projection above it.
    let sorted = LogicalPlan::Sort(Sort {
        expr: vec![column_a.sort(true, true)],
        input: Arc::new(projection),
        fetch: None,
    });
    coerce_on_output_if_viewtype!(
        false,
        sorted.clone(),
        @r"
    Sort: a ASC NULLS FIRST
      Projection: a
        EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        sorted.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    // An explicit outer projection receives the cast directly.
    let outer = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sorted),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        outer.clone(),
        @r"
    Projection: a
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        outer.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    Ok(())
}
1536
/// Both operands of an `OR` of two `a < 2u32` comparisons are coerced independently.
#[test]
fn nested_case() -> Result<()> {
    let comparison = col("a").lt(lit(2_u32));
    let disjunction = comparison.clone().or(comparison);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![disjunction],
        empty_with_type(DataType::Float64),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
      EmptyRelation: rows=0
    "
    )
}
1555
1556 #[derive(Debug, PartialEq, Eq, Hash)]
1557 struct TestScalarUDF {
1558 signature: Signature,
1559 }
1560
1561 impl ScalarUDFImpl for TestScalarUDF {
1562 fn as_any(&self) -> &dyn Any {
1563 self
1564 }
1565
1566 fn name(&self) -> &str {
1567 "TestScalarUDF"
1568 }
1569
1570 fn signature(&self) -> &Signature {
1571 &self.signature
1572 }
1573
1574 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1575 Ok(Utf8)
1576 }
1577
1578 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1579 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1580 }
1581 }
1582
/// An Int32 argument to a UDF with a uniform Float32 signature is cast to Float32.
#[test]
fn scalar_udf() -> Result<()> {
    let udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = udf.call(vec![lit(123_i32)]);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
      EmptyRelation: rows=0
    "
    )
}
1601
/// A Utf8 argument cannot satisfy a uniform Float32 signature, so building
/// the projection fails.
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
    let float_only = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float_only.call(vec![lit("Apple")]);
    let outcome = Projection::try_new(vec![call], empty());
    outcome.expect_err("Expected an error due to incorrect function input");

    Ok(())
}
1614
/// When the `Expr::ScalarFunction` node is built directly, an Int64 argument
/// to a uniform Float32 signature is cast to Float32.
#[test]
fn scalar_function() -> Result<()> {
    let udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    }));
    let call = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(10i64)]));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
      EmptyRelation: rows=0
    "
    )
}
1638
/// An Int64 argument to a UDAF declared over Float64 is cast to Float64.
#[test]
fn agg_udaf() -> Result<()> {
    let accumulator_factory = Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let state_types = Arc::new(vec![DataType::UInt64, DataType::Float64]);
    let my_avg = Arc::new(create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        accumulator_factory,
        state_types,
    ));

    let call = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        my_avg,
        vec![lit(10i64)],
        false,
        None,
        vec![],
        None,
    ));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: MY_AVG(CAST(Int64(10) AS Float64))
      EmptyRelation: rows=0
    "
    )
}
1668
/// A Utf8 argument to a UDAF whose signature is `Uniform(1, [Float64])` cannot
/// be coerced, so building the projection fails with a planning error.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let empty = empty();
    let return_type = DataType::Float64;
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
        "MY_AVG",
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        return_type,
        accumulator,
        vec![
            Field::new("count", DataType::UInt64, true).into(),
            Field::new("avg", DataType::Float64, true).into(),
        ],
    ));
    let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit("10")],
        false,
        None,
        vec![],
        None,
    ));

    // `expect_err` reports the unexpectedly built projection instead of a bare
    // `Option::unwrap` panic, and the assert message shows the actual error.
    let err = Projection::try_new(vec![udaf], empty)
        .expect_err("expected argument coercion for MY_AVG to fail")
        .strip_backtrace();
    let expected = "Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed";
    assert!(
        err.starts_with(expected),
        "unexpected error.\n expected prefix: `{expected}`\n actual error: `{err}`"
    );
    Ok(())
}
1700
/// `avg` over a Float64 literal and over an explicitly cast Int32 column needs
/// no further coercion.
#[test]
fn agg_function_case() -> Result<()> {
    // Builds an `avg(arg)` aggregate expression with the same fixed options in both cases.
    let avg_of = |arg: Expr| {
        Expr::AggregateFunction(expr::AggregateFunction::new_udf(
            avg_udaf(),
            vec![arg],
            false,
            None,
            vec![],
            None,
        ))
    };

    let literal_plan =
        LogicalPlan::Projection(Projection::try_new(vec![avg_of(lit(12f64))], empty())?);
    assert_analyzed_plan_eq!(
        literal_plan,
        @r"
    Projection: avg(Float64(12))
      EmptyRelation: rows=0
    "
    )?;

    let column_plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of(cast(col("a"), DataType::Float64))],
        empty_with_type(DataType::Int32),
    )?);
    assert_analyzed_plan_eq!(
        column_plan,
        @r"
    Projection: avg(CAST(a AS Float64))
      EmptyRelation: rows=0
    "
    )
}
1741
/// `avg` over a Utf8 literal fails, because Utf8 is not among the numeric types
/// in `avg`'s uniform signature.
#[test]
fn agg_function_invalid_input_avg() -> Result<()> {
    let empty = empty();
    let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![lit("1")],
        false,
        None,
        vec![],
        None,
    ));

    // `expect_err` names the failure if the projection unexpectedly builds, and
    // the assert message shows the actual error on a mismatch.
    let err = Projection::try_new(vec![agg_expr], empty)
        .expect_err("expected argument coercion for avg to fail")
        .strip_backtrace();
    let expected = "Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed";
    assert!(
        err.starts_with(expected),
        "unexpected error.\n expected prefix: `{expected}`\n actual error: `{err}`"
    );
    Ok(())
}
1760
/// `Date32 + IntervalDayTime` is accepted without inserting further casts.
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let interval = lit(ScalarValue::new_interval_dt(123, 456));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![date + interval], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
      EmptyRelation: rows=0
    "#
    )
}
1777
/// `IN` lists with mixed integer widths are coerced to the column's type.
///
/// For a Decimal128(12, 4) column, the column and all items are cast to
/// Decimal128(24, 4).
#[test]
fn inlist_case() -> Result<()> {
    // The same mixed-width list is checked against two column types.
    let in_list = || col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);

    let int64_input = empty_with_type(DataType::Int64);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![in_list()], int64_input)?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
      EmptyRelation: rows=0
    "
    )?;

    let decimal_schema = DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
        std::collections::HashMap::new(),
    )?;
    let decimal_input = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(decimal_schema),
    }));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![in_list()], decimal_input)?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
      EmptyRelation: rows=0
    "
    )
}
1808
/// A Utf8 column `BETWEEN` a Utf8 literal and a Date32 expression is coerced
/// to Date32 on every side.
#[test]
fn between_case() -> Result<()> {
    let low = lit("2002-05-08");
    let high =
        cast(lit("2002-05-08"), DataType::Date32) + lit(ScalarValue::new_interval_ym(0, 1));
    let plan = LogicalPlan::Filter(Filter::try_new(
        col("a").between(low, high),
        empty_with_type(Utf8),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
      EmptyRelation: rows=0
    "#
    )
}
1828
/// With the Date32 expression as the lower bound and a Utf8 literal as the
/// upper bound, the Utf8 column and literal are both cast to Date32.
#[test]
fn between_infer_cheap_type() -> Result<()> {
    let low =
        cast(lit("2002-05-08"), DataType::Date32) + lit(ScalarValue::new_interval_ym(0, 1));
    let high = lit("2002-12-08");
    let plan = LogicalPlan::Filter(Filter::try_new(
        col("a").between(low, high),
        empty_with_type(Utf8),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
      EmptyRelation: rows=0
    "#
    )
}
1849
/// `NULL BETWEEN NULL AND 2i64` casts both NULLs to Int64.
#[test]
fn between_null() -> Result<()> {
    let null = || lit(ScalarValue::Null);
    let predicate = null().between(null(), lit(2i64));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
      EmptyRelation: rows=0
    "
    )
}
1864
/// `IS [NOT] TRUE/FALSE` on a Boolean column passes through unchanged. On an
/// Int64 column, `IS TRUE` is rejected.
#[test]
fn is_bool_for_type_coercion() -> Result<()> {
    // Single-expression projection over an empty relation with a column `a` of `data_type`.
    let project = |predicate: Expr, data_type: DataType| -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![predicate],
            empty_with_type(data_type),
        )?))
    };

    let is_true_expr = col("a").is_true();
    let plan = project(is_true_expr.clone(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS TRUE
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(is_true_expr, DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
    )?;

    let plan = project(col("a").is_not_true(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT TRUE
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(col("a").is_false(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS FALSE
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(col("a").is_not_false(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT FALSE
      EmptyRelation: rows=0
    "
    )
}
1927
/// LIKE/ILIKE against a Utf8 column accepts Utf8 patterns and casts a NULL
/// pattern to Utf8. Against an Int64 column, both are rejected.
#[test]
fn like_for_type_coercion() -> Result<()> {
    // `a LIKE pattern`, or `a ILIKE pattern` when `case_insensitive` is true.
    let like = |pattern: ScalarValue, case_insensitive: bool| {
        Expr::Like(Like::new(
            false,
            Box::new(col("a")),
            Box::new(lit(pattern)),
            None,
            case_insensitive,
        ))
    };
    // Single-expression projection over an empty relation with a column `a` of `data_type`.
    let project = |predicate: Expr, data_type: DataType| -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![predicate],
            empty_with_type(data_type),
        )?))
    };

    // LIKE
    let plan = project(like(ScalarValue::new_utf8("abc"), false), Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a LIKE Utf8("abc")
      EmptyRelation: rows=0
    "#
    )?;

    let plan = project(like(ScalarValue::Null, false), Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a LIKE CAST(NULL AS Utf8)
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(like(ScalarValue::new_utf8("abc"), false), DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
    )?;

    // ILIKE
    let plan = project(like(ScalarValue::new_utf8("abc"), true), Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a ILIKE Utf8("abc")
      EmptyRelation: rows=0
    "#
    )?;

    let plan = project(like(ScalarValue::Null, true), Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a ILIKE CAST(NULL AS Utf8)
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(like(ScalarValue::new_utf8("abc"), true), DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
    )?;

    Ok(())
}
2010
/// `IS [NOT] UNKNOWN` on a Boolean column passes through unchanged. On a Utf8
/// column, `IS UNKNOWN` is rejected.
#[test]
fn unknown_for_type_coercion() -> Result<()> {
    // Single-expression projection over an empty relation with a column `a` of `data_type`.
    let project = |predicate: Expr, data_type: DataType| -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![predicate],
            empty_with_type(data_type),
        )?))
    };

    let is_unknown_expr = col("a").is_unknown();
    let plan = project(is_unknown_expr.clone(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS UNKNOWN
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(is_unknown_expr, Utf8)?;
    assert_type_coercion_error(
        plan,
        "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
    )?;

    let plan = project(col("a").is_not_unknown(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT UNKNOWN
      EmptyRelation: rows=0
    "
    )
}
2047
/// A variadic Utf8 UDF keeps Utf8 arguments and casts the Boolean and Int32
/// literals to Utf8.
#[test]
fn concat_for_type_coercion() -> Result<()> {
    let variadic_utf8 = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
    });
    let call = variadic_utf8.call(vec![col("a"), lit("b"), lit(true), lit(false), lit(13)]);
    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![call], empty_with_type(Utf8))?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
      EmptyRelation: rows=0
    "#
    )
}
2068
/// `TypeCoercionRewriter` widens the Int32 side of `>`, `=` and `<` against an
/// Int64 literal, leaving the enclosing `IS TRUE` in place.
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Int64, true)].into(),
        std::collections::HashMap::new(),
    )?);

    // One entry per comparison operator under test.
    let comparisons: [fn(Expr, Expr) -> Expr; 3] =
        [|l, r| l.gt(r), |l, r| l.eq(r), |l, r| l.lt(r)];

    for compare in comparisons {
        let mut rewriter = TypeCoercionRewriter { schema: &schema };
        let input = is_true(compare(lit(12i32), lit(13i64)));
        let expected = is_true(compare(cast(lit(12i32), DataType::Int64), lit(13i64)));
        let rewritten = input.rewrite(&mut rewriter).data()?;
        assert_eq!(expected, rewritten);
    }

    Ok(())
}
2106
/// Comparing a nanosecond timestamp with a Date32 casts the date to the timestamp type.
#[test]
fn binary_op_date32_eq_ts() -> Result<()> {
    let timestamp = cast(
        lit("1998-03-18"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![timestamp.eq(date)], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2125
2126 fn cast_if_not_same_type(
2127 expr: Box<Expr>,
2128 data_type: &DataType,
2129 schema: &DFSchemaRef,
2130 ) -> Box<Expr> {
2131 if &expr.get_type(schema).unwrap() != data_type {
2132 Box::new(cast(*expr, data_type.clone()))
2133 } else {
2134 expr
2135 }
2136 }
2137
2138 fn cast_helper(
2139 case: Case,
2140 case_when_type: &DataType,
2141 then_else_type: &DataType,
2142 schema: &DFSchemaRef,
2143 ) -> Case {
2144 let expr = case
2145 .expr
2146 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2147 let when_then_expr = case
2148 .when_then_expr
2149 .into_iter()
2150 .map(|(when, then)| {
2151 (
2152 cast_if_not_same_type(when, case_when_type, schema),
2153 cast_if_not_same_type(then, then_else_type, schema),
2154 )
2155 })
2156 .collect::<Vec<_>>();
2157 let else_expr = case
2158 .else_expr
2159 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2160
2161 Case {
2162 expr,
2163 when_then_expr,
2164 else_expr,
2165 }
2166 }
2167
/// Calls `coerce_case_expression` directly on `Case` values.
///
/// Two cases have common types and are compared against the casts
/// `cast_helper` would insert. The other two fail: one on the WHEN side and
/// one on the THEN/ELSE side.
#[test]
fn test_case_expression_coercion() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new("integer", DataType::Int32, true),
            Field::new("float", DataType::Float32, true),
            Field::new(
                "timestamp",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                true,
            ),
            Field::new("date", DataType::Date32, true),
            Field::new(
                "interval",
                DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
                true,
            ),
            Field::new("binary", DataType::Binary, true),
            Field::new("string", Utf8, true),
            Field::new("decimal", DataType::Decimal128(10, 10), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // Shorthands for a boxed column reference and a WHEN/THEN pair of columns.
    let column = |name: &str| Box::new(col(name));
    let arm = |when: &str, then: &str| (column(when), column(then));

    // Searched CASE: WHEN values are expected to coerce to Boolean, THEN values to Utf8.
    let case = Case {
        expr: None,
        when_then_expr: vec![
            arm("boolean", "integer"),
            arm("integer", "float"),
            arm("string", "string"),
        ],
        else_expr: None,
    };
    let expected = cast_helper(case.clone(), &DataType::Boolean, &Utf8, &schema);
    assert_eq!(expected, coerce_case_expression(case, &schema)?);

    // Simple CASE on a Utf8 column: everything is expected to coerce to Utf8.
    let case = Case {
        expr: Some(column("string")),
        when_then_expr: vec![
            arm("float", "integer"),
            arm("integer", "float"),
            arm("string", "string"),
        ],
        else_expr: Some(column("string")),
    };
    let expected = cast_helper(case.clone(), &Utf8, &Utf8, &schema);
    assert_eq!(expected, coerce_case_expression(case, &schema)?);

    // No common type for the Interval base and the WHEN values.
    let case = Case {
        expr: Some(column("interval")),
        when_then_expr: vec![
            arm("float", "integer"),
            arm("binary", "float"),
            arm("string", "string"),
        ],
        else_expr: Some(column("string")),
    };
    let err = coerce_case_expression(case, &schema).unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
    );

    // No common type for the THEN values and the ELSE value.
    let case = Case {
        expr: Some(column("string")),
        when_then_expr: vec![
            arm("float", "date"),
            arm("string", "float"),
            arm("string", "binary"),
        ],
        else_expr: Some(column("timestamp")),
    };
    let err = coerce_case_expression(case, &schema).unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
    );

    Ok(())
}
2266
/// Builds a `CASE` from an optional base column name and the given WHEN/THEN
/// arms, with no ELSE. It then asserts that `coerce_case_expression` produces
/// exactly what `cast_helper` builds for the supplied common types.
///
/// Must be expanded inside a function returning `Result`, since it uses `?`.
macro_rules! test_case_expression {
    ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {{
        let schema_ref = &$schema;
        let case_expr = Case {
            expr: $expr.map(|name| Box::new(col(name))),
            when_then_expr: $when_then,
            else_expr: None,
        };

        let want = cast_helper(
            case_expr.clone(),
            &$case_when_type,
            &$then_else_type,
            schema_ref,
        );
        let got = coerce_case_expression(case_expr, schema_ref)?;
        assert_eq!(want, got);
    }};
}
2282
/// Coercion between a list-typed CASE base and a list-typed WHEN value.
///
/// The common type is LargeList when a LargeList is involved, and List
/// otherwise (List vs FixedSizeList).
///
/// NOTE(review): the `tes_` prefix looks like a typo for `test_`. The name is
/// kept so existing test filters keep matching.
#[test]
fn tes_case_when_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new(
                "large_list",
                DataType::LargeList(Arc::clone(&inner_field)),
                true,
            ),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(Arc::clone(&inner_field), 3),
                true,
            ),
            Field::new("list", DataType::List(inner_field), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    let large_list =
        || DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true)));
    let list = || DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));

    // (base column, WHEN column, expected common base/WHEN type)
    let cases = [
        ("list", "large_list", large_list()),
        ("large_list", "list", large_list()),
        ("list", "fixed_list", list()),
        ("fixed_list", "list", list()),
        ("fixed_list", "large_list", large_list()),
        ("large_list", "fixed_list", large_list()),
    ];

    for (base, when, expected_type) in cases {
        test_case_expression!(
            Some(base),
            vec![(Box::new(col(when)), Box::new(lit("1")))],
            expected_type,
            Utf8,
            schema
        );
    }
    Ok(())
}
2353
/// Coercion between list-typed THEN values of a searched CASE.
///
/// The common type is LargeList when a LargeList is involved, and List
/// otherwise (List vs FixedSizeList). Argument order does not matter.
#[test]
fn test_then_else_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new(
                "large_list",
                DataType::LargeList(Arc::clone(&inner_field)),
                true,
            ),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(Arc::clone(&inner_field), 3),
                true,
            ),
            Field::new("list", DataType::List(inner_field), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    let large_list =
        || DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true)));
    let list = || DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));

    // (first THEN column, second THEN column, expected common THEN type)
    let cases = [
        ("large_list", "list", large_list()),
        ("list", "large_list", large_list()),
        ("fixed_list", "list", list()),
        ("list", "fixed_list", list()),
        ("fixed_list", "large_list", large_list()),
        ("large_list", "fixed_list", large_list()),
    ];

    for (first, second, expected_then) in cases {
        test_case_expression!(
            None::<String>,
            vec![
                (Box::new(col("boolean")), Box::new(col(first))),
                (Box::new(col("boolean")), Box::new(col(second))),
            ],
            DataType::Boolean,
            expected_then,
            schema
        );
    }
    Ok(())
}
2446
/// Two Map types differ only in the name of their entries field.
///
/// Comparing the `entries`-typed column with its cast to the `key_value` map
/// casts that result back to the column's map type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    let mut builder = SchemaBuilder::new();
    builder.push(Field::new("key", Utf8, false));
    builder.push(Field::new("value", DataType::Float64, true));
    let struct_fields = builder.finish().fields;

    // Same key/value struct, wrapped under a given entries-field name.
    let map_named = |entries_name: &str, fields| {
        DataType::Map(
            Arc::new(Field::new(entries_name, DataType::Struct(fields), false)),
            false,
        )
    };
    let map_type_entries = map_named("entries", struct_fields.clone());
    let map_type_custom = map_named("key_value", struct_fields);

    let comparison = col("a").eq(cast(col("a"), map_type_custom));
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        empty_with_type(map_type_entries),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a = CAST(CAST(a AS Map("key_value": Struct("key": Utf8, "value": nullable Float64), unsorted)) AS Map("entries": Struct("key": Utf8, "value": nullable Float64), unsorted))
      EmptyRelation: rows=0
    "#
    )
}
2473
/// `IntervalYearMonth + Timestamp(ns)`: the expected snapshot is identical to
/// the input expression, i.e. the analyzer inserts no additional casts.
#[test]
fn interval_plus_timestamp() -> Result<()> {
    let interval = lit(ScalarValue::IntervalYearMonth(Some(12)));
    let timestamp = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(interval),
        Operator::Plus,
        Box::new(timestamp),
    ));

    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2496
/// `Timestamp(ns) - Timestamp(ns)`: both operands already share a type, and
/// the expected snapshot shows the expression is left unchanged.
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both sides are the same cast expression; build it once and reuse it.
    let ts = cast(
        lit("1998-03-18"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let difference = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(ts.clone()),
        Operator::Minus,
        Box::new(ts),
    ));

    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![difference], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2521
/// `a IN (subquery)` where the subquery input is Int32 and the outer input is
/// Int64: the expected snapshot adds a `CAST(a AS Int64)` projection on top of
/// the subquery and leaves the outer `a` untouched.
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Int32),
        outer_ref_columns: Vec::new(),
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let outer_input = empty_with_type(DataType::Int64);

    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: a IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Int64)
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2550
/// `a IN (subquery)` where the subquery input is Int64 and the outer input is
/// Int32: the expected snapshot casts the outer `a` to Int64 and leaves the
/// subquery plan unchanged.
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Int64),
        outer_ref_columns: Vec::new(),
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let outer_input = empty_with_type(DataType::Int32);

    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Int64) IN (<subquery>)
      Subquery:
        EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2578
/// `a IN (subquery)` with Decimal128(10, 5) inside and Decimal128(8, 8)
/// outside: neither side fits the other. The expected snapshot casts both the
/// outer `a` and the subquery projection to the common Decimal128(13, 8).
#[test]
fn in_subquery_cast_all() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Decimal128(10, 5)),
        outer_ref_columns: Vec::new(),
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let outer_input = empty_with_type(DataType::Decimal128(8, 8));

    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Decimal128(13, 8))
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2607}