1use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::izip;
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32 exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
33 DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34};
35use datafusion_expr::expr::{
36 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
37 InSubquery, Like, ScalarFunction, Sort, WindowFunction,
38};
39use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
40use datafusion_expr::expr_schema::cast_subquery;
41use datafusion_expr::logical_plan::Subquery;
42use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
43use datafusion_expr::type_coercion::functions::{
44 data_types_with_scalar_udf, fields_with_aggregate_udf,
45};
46use datafusion_expr::type_coercion::other::{
47 get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52 is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
53 AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection,
54 ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
55};
56
/// Analyzer rule that rewrites a [`LogicalPlan`] so that expressions and plan
/// nodes receive inputs of the types they expect, inserting casts where needed.
///
/// The rule carries no state; build it with [`TypeCoercion::new`] or
/// [`Default::default`].
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new type coercion analyzer rule.
    pub fn new() -> Self {
        Self::default()
    }
}
67
68fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
70 if !config.optimizer.expand_views_at_output {
71 return Ok(plan);
72 }
73
74 let outer_refs = plan.expressions();
75 if outer_refs.is_empty() {
76 return Ok(plan);
77 }
78
79 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
80 coerce_plan_expr_for_schema(plan, &dfschema?)
81 } else {
82 Ok(plan)
83 }
84}
85
86impl AnalyzerRule for TypeCoercion {
87 fn name(&self) -> &str {
88 "type_coercion"
89 }
90
91 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
92 let empty_schema = DFSchema::empty();
93
94 let transformed_plan = plan
96 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
97 .data;
98
99 coerce_output(transformed_plan, config)
101 }
102}
103
104fn analyze_internal(
108 external_schema: &DFSchema,
109 plan: LogicalPlan,
110) -> Result<Transformed<LogicalPlan>> {
111 let mut schema = merge_schema(&plan.inputs());
114
115 if let LogicalPlan::TableScan(ts) = &plan {
116 let source_schema = DFSchema::try_from_qualified_schema(
117 ts.table_name.clone(),
118 &ts.source.schema(),
119 )?;
120 schema.merge(&source_schema);
121 }
122
123 schema.merge(external_schema);
127
128 let plan = if let LogicalPlan::Filter(mut filter) = plan {
130 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
131 LogicalPlan::Filter(filter)
132 } else {
133 plan
134 };
135
136 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
137
138 let name_preserver = NamePreserver::new(&plan);
139 plan.map_expressions(|expr| {
141 let original_name = name_preserver.save(&expr);
142 expr.rewrite(&mut expr_rewrite)
143 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
144 })?
145 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
147 .map_data(|plan| plan.recompute_schema())
149}
150
151pub struct TypeCoercionRewriter<'a> {
153 pub(crate) schema: &'a DFSchema,
154}
155
156impl<'a> TypeCoercionRewriter<'a> {
157 pub fn new(schema: &'a DFSchema) -> Self {
160 Self { schema }
161 }
162
163 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
168 match plan {
169 LogicalPlan::Join(join) => self.coerce_join(join),
170 LogicalPlan::Union(union) => Self::coerce_union(union),
171 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
172 _ => Ok(plan),
173 }
174 }
175
176 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
185 join.on = join
186 .on
187 .into_iter()
188 .map(|(lhs, rhs)| {
189 let left_schema = join.left.schema();
192 let right_schema = join.right.schema();
193 let (lhs, rhs) = self.coerce_binary_op(
194 lhs,
195 left_schema,
196 Operator::Eq,
197 rhs,
198 right_schema,
199 )?;
200 Ok((lhs, rhs))
201 })
202 .collect::<Result<Vec<_>>>()?;
203
204 join.filter = join
206 .filter
207 .map(|expr| self.coerce_join_filter(expr))
208 .transpose()?;
209
210 Ok(LogicalPlan::Join(join))
211 }
212
213 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
216 let union_schema = Arc::new(coerce_union_schema_with_schema(
217 &union_plan.inputs,
218 &union_plan.schema,
219 )?);
220 let new_inputs = union_plan
221 .inputs
222 .into_iter()
223 .map(|p| {
224 let plan =
225 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
226 match plan {
227 LogicalPlan::Projection(Projection { expr, input, .. }) => {
228 Ok(Arc::new(project_with_column_index(
229 expr,
230 input,
231 Arc::clone(&union_schema),
232 )?))
233 }
234 other_plan => Ok(Arc::new(other_plan)),
235 }
236 })
237 .collect::<Result<Vec<_>>>()?;
238 Ok(LogicalPlan::Union(Union {
239 inputs: new_inputs,
240 schema: union_schema,
241 }))
242 }
243
244 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
246 fn coerce_limit_expr(
247 expr: Expr,
248 schema: &DFSchema,
249 expr_name: &str,
250 ) -> Result<Expr> {
251 let dt = expr.get_type(schema)?;
252 if dt.is_integer() || dt.is_null() {
253 expr.cast_to(&DataType::Int64, schema)
254 } else {
255 plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}")
256 }
257 }
258
259 let empty_schema = DFSchema::empty();
260 let new_fetch = limit
261 .fetch
262 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
263 .transpose()?;
264 let new_skip = limit
265 .skip
266 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
267 .transpose()?;
268 Ok(LogicalPlan::Limit(Limit {
269 input: limit.input,
270 fetch: new_fetch.map(Box::new),
271 skip: new_skip.map(Box::new),
272 }))
273 }
274
275 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
276 let expr_type = expr.get_type(self.schema)?;
277 match expr_type {
278 DataType::Boolean => Ok(expr),
279 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
280 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
281 }
282 }
283
284 fn coerce_binary_op(
285 &self,
286 left: Expr,
287 left_schema: &DFSchema,
288 op: Operator,
289 right: Expr,
290 right_schema: &DFSchema,
291 ) -> Result<(Expr, Expr)> {
292 let (left_type, right_type) = BinaryTypeCoercer::new(
293 &left.get_type(left_schema)?,
294 &op,
295 &right.get_type(right_schema)?,
296 )
297 .get_input_types()?;
298
299 Ok((
300 left.cast_to(&left_type, left_schema)?,
301 right.cast_to(&right_type, right_schema)?,
302 ))
303 }
304}
305
306impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
307 type Node = Expr;
308
309 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
310 match expr {
311 Expr::Unnest(_) => not_impl_err!(
312 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
313 ),
314 Expr::ScalarSubquery(Subquery {
315 subquery,
316 outer_ref_columns,
317 spans,
318 }) => {
319 let new_plan =
320 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
321 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
322 subquery: Arc::new(new_plan),
323 outer_ref_columns,
324 spans,
325 })))
326 }
327 Expr::Exists(Exists { subquery, negated }) => {
328 let new_plan = analyze_internal(
329 self.schema,
330 Arc::unwrap_or_clone(subquery.subquery),
331 )?
332 .data;
333 Ok(Transformed::yes(Expr::Exists(Exists {
334 subquery: Subquery {
335 subquery: Arc::new(new_plan),
336 outer_ref_columns: subquery.outer_ref_columns,
337 spans: subquery.spans,
338 },
339 negated,
340 })))
341 }
342 Expr::InSubquery(InSubquery {
343 expr,
344 subquery,
345 negated,
346 }) => {
347 let new_plan = analyze_internal(
348 self.schema,
349 Arc::unwrap_or_clone(subquery.subquery),
350 )?
351 .data;
352 let expr_type = expr.get_type(self.schema)?;
353 let subquery_type = new_plan.schema().field(0).data_type();
354 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
355 "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
356 ),
357 )?;
358 let new_subquery = Subquery {
359 subquery: Arc::new(new_plan),
360 outer_ref_columns: subquery.outer_ref_columns,
361 spans: subquery.spans,
362 };
363 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
364 Box::new(expr.cast_to(&common_type, self.schema)?),
365 cast_subquery(new_subquery, &common_type)?,
366 negated,
367 ))))
368 }
369 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
370 *expr,
371 self.schema,
372 )?))),
373 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
374 get_casted_expr_for_bool_op(*expr, self.schema)?,
375 ))),
376 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
377 get_casted_expr_for_bool_op(*expr, self.schema)?,
378 ))),
379 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
380 get_casted_expr_for_bool_op(*expr, self.schema)?,
381 ))),
382 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
383 get_casted_expr_for_bool_op(*expr, self.schema)?,
384 ))),
385 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
386 get_casted_expr_for_bool_op(*expr, self.schema)?,
387 ))),
388 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
389 get_casted_expr_for_bool_op(*expr, self.schema)?,
390 ))),
391 Expr::Like(Like {
392 negated,
393 expr,
394 pattern,
395 escape_char,
396 case_insensitive,
397 }) => {
398 let left_type = expr.get_type(self.schema)?;
399 let right_type = pattern.get_type(self.schema)?;
400 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
401 let op_name = if case_insensitive {
402 "ILIKE"
403 } else {
404 "LIKE"
405 };
406 plan_datafusion_err!(
407 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
408 )
409 })?;
410 let expr = match left_type {
411 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
412 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
413 };
414 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
415 Ok(Transformed::yes(Expr::Like(Like::new(
416 negated,
417 expr,
418 pattern,
419 escape_char,
420 case_insensitive,
421 ))))
422 }
423 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
424 let (left, right) =
425 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
426 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
427 Box::new(left),
428 op,
429 Box::new(right),
430 ))))
431 }
432 Expr::Between(Between {
433 expr,
434 negated,
435 low,
436 high,
437 }) => {
438 let expr_type = expr.get_type(self.schema)?;
439 let low_type = low.get_type(self.schema)?;
440 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
441 .ok_or_else(|| {
442 DataFusionError::Internal(format!(
443 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
444 ))
445 })?;
446 let high_type = high.get_type(self.schema)?;
447 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
448 .ok_or_else(|| {
449 DataFusionError::Internal(format!(
450 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
451 ))
452 })?;
453 let coercion_type =
454 comparison_coercion(&low_coerced_type, &high_coerced_type)
455 .ok_or_else(|| {
456 DataFusionError::Internal(format!(
457 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
458 ))
459 })?;
460 Ok(Transformed::yes(Expr::Between(Between::new(
461 Box::new(expr.cast_to(&coercion_type, self.schema)?),
462 negated,
463 Box::new(low.cast_to(&coercion_type, self.schema)?),
464 Box::new(high.cast_to(&coercion_type, self.schema)?),
465 ))))
466 }
467 Expr::InList(InList {
468 expr,
469 list,
470 negated,
471 }) => {
472 let expr_data_type = expr.get_type(self.schema)?;
473 let list_data_types = list
474 .iter()
475 .map(|list_expr| list_expr.get_type(self.schema))
476 .collect::<Result<Vec<_>>>()?;
477 let result_type =
478 get_coerce_type_for_list(&expr_data_type, &list_data_types);
479 match result_type {
480 None => plan_err!(
481 "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
482 ),
483 Some(coerced_type) => {
484 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
486 let cast_list_expr = list
487 .into_iter()
488 .map(|list_expr| {
489 list_expr.cast_to(&coerced_type, self.schema)
490 })
491 .collect::<Result<Vec<_>>>()?;
492 Ok(Transformed::yes(Expr::InList(InList ::new(
493 Box::new(cast_expr),
494 cast_list_expr,
495 negated,
496 ))))
497 }
498 }
499 }
500 Expr::Case(case) => {
501 let case = coerce_case_expression(case, self.schema)?;
502 Ok(Transformed::yes(Expr::Case(case)))
503 }
504 Expr::ScalarFunction(ScalarFunction { func, args }) => {
505 let new_expr = coerce_arguments_for_signature_with_scalar_udf(
506 args,
507 self.schema,
508 &func,
509 )?;
510 Ok(Transformed::yes(Expr::ScalarFunction(
511 ScalarFunction::new_udf(func, new_expr),
512 )))
513 }
514 Expr::AggregateFunction(expr::AggregateFunction {
515 func,
516 params:
517 AggregateFunctionParams {
518 args,
519 distinct,
520 filter,
521 order_by,
522 null_treatment,
523 },
524 }) => {
525 let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
526 args,
527 self.schema,
528 &func,
529 )?;
530 Ok(Transformed::yes(Expr::AggregateFunction(
531 expr::AggregateFunction::new_udf(
532 func,
533 new_expr,
534 distinct,
535 filter,
536 order_by,
537 null_treatment,
538 ),
539 )))
540 }
541 Expr::WindowFunction(window_fun) => {
542 let WindowFunction {
543 fun,
544 params:
545 expr::WindowFunctionParams {
546 args,
547 partition_by,
548 order_by,
549 window_frame,
550 filter,
551 null_treatment,
552 distinct,
553 },
554 } = *window_fun;
555 let window_frame =
556 coerce_window_frame(window_frame, self.schema, &order_by)?;
557
558 let args = match &fun {
559 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
560 coerce_arguments_for_signature_with_aggregate_udf(
561 args,
562 self.schema,
563 udf,
564 )?
565 }
566 _ => args,
567 };
568
569 let new_expr = Expr::from(WindowFunction {
570 fun,
571 params: expr::WindowFunctionParams {
572 args,
573 partition_by,
574 order_by,
575 window_frame,
576 filter,
577 null_treatment,
578 distinct,
579 },
580 });
581 Ok(Transformed::yes(new_expr))
582 }
583 #[expect(deprecated)]
585 Expr::Alias(_)
586 | Expr::Column(_)
587 | Expr::ScalarVariable(_, _)
588 | Expr::Literal(_, _)
589 | Expr::SimilarTo(_)
590 | Expr::IsNotNull(_)
591 | Expr::IsNull(_)
592 | Expr::Negative(_)
593 | Expr::Cast(_)
594 | Expr::TryCast(_)
595 | Expr::Wildcard { .. }
596 | Expr::GroupingSet(_)
597 | Expr::Placeholder(_)
598 | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
599 }
600 }
601}
602
603fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
605 let metadata = dfschema.as_arrow().metadata.clone();
606 let mut transformed = false;
607
608 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
609 dfschema
610 .iter()
611 .map(|(qualifier, field)| match field.data_type() {
612 DataType::Utf8View => {
613 transformed = true;
614 (
615 qualifier.cloned() as Option<TableReference>,
616 Arc::new(Field::new(
617 field.name(),
618 DataType::LargeUtf8,
619 field.is_nullable(),
620 )),
621 )
622 }
623 DataType::BinaryView => {
624 transformed = true;
625 (
626 qualifier.cloned() as Option<TableReference>,
627 Arc::new(Field::new(
628 field.name(),
629 DataType::LargeBinary,
630 field.is_nullable(),
631 )),
632 )
633 }
634 _ => (
635 qualifier.cloned() as Option<TableReference>,
636 Arc::clone(field),
637 ),
638 })
639 .unzip();
640
641 if !transformed {
642 return None;
643 }
644
645 let schema = Schema::new_with_metadata(transformed_fields, metadata);
646 Some(DFSchema::from_field_specific_qualified_schema(
647 qualifiers,
648 &Arc::new(schema),
649 ))
650}
651
652fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
655 match value {
656 ScalarValue::Utf8(Some(val)) => {
658 ScalarValue::try_from_string(val.clone(), target_type)
659 }
660 s => {
661 if s.is_null() {
662 ScalarValue::try_from(target_type)
664 } else {
665 Ok(s.clone())
669 }
670 }
671 }
672}
673
674fn coerce_scalar_range_aware(
681 target_type: &DataType,
682 value: &ScalarValue,
683) -> Result<ScalarValue> {
684 coerce_scalar(target_type, value).or_else(|err| {
685 if let Some(largest_type) = get_widest_type_in_family(target_type) {
687 coerce_scalar(largest_type, value).map_or_else(
688 |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
689 |_| ScalarValue::try_from(target_type),
690 )
691 } else {
692 Err(err)
693 }
694 })
695}
696
697fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
701 match given_type {
702 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
703 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
704 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
705 _ => None,
706 }
707}
708
709fn coerce_frame_bound(
711 target_type: &DataType,
712 bound: WindowFrameBound,
713) -> Result<WindowFrameBound> {
714 match bound {
715 WindowFrameBound::Preceding(v) => {
716 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
717 }
718 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
719 WindowFrameBound::Following(v) => {
720 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
721 }
722 }
723}
724
725fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
726 if col_type.is_numeric()
727 || is_utf8_or_utf8view_or_large_utf8(col_type)
728 || matches!(col_type, DataType::List(_))
729 || matches!(col_type, DataType::LargeList(_))
730 || matches!(col_type, DataType::FixedSizeList(_, _))
731 || matches!(col_type, DataType::Null)
732 || matches!(col_type, DataType::Boolean)
733 {
734 Ok(col_type.clone())
735 } else if is_datetime(col_type) {
736 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
737 } else if let DataType::Dictionary(_, value_type) = col_type {
738 extract_window_frame_target_type(value_type)
739 } else {
740 internal_err!("Cannot run range queries on datatype: {col_type:?}")
741 }
742}
743
744fn coerce_window_frame(
747 window_frame: WindowFrame,
748 schema: &DFSchema,
749 expressions: &[Sort],
750) -> Result<WindowFrame> {
751 let mut window_frame = window_frame;
752 let target_type = match window_frame.units {
753 WindowFrameUnits::Range => {
754 let current_types = expressions
755 .first()
756 .map(|s| s.expr.get_type(schema))
757 .transpose()?;
758 if let Some(col_type) = current_types {
759 extract_window_frame_target_type(&col_type)?
760 } else {
761 return internal_err!("ORDER BY column cannot be empty");
762 }
763 }
764 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
765 };
766 window_frame.start_bound =
767 coerce_frame_bound(&target_type, window_frame.start_bound)?;
768 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
769 Ok(window_frame)
770}
771
772fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
775 let left_type = expr.get_type(schema)?;
776 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
777 .get_input_types()?;
778 expr.cast_to(&DataType::Boolean, schema)
779}
780
781fn coerce_arguments_for_signature_with_scalar_udf(
786 expressions: Vec<Expr>,
787 schema: &DFSchema,
788 func: &ScalarUDF,
789) -> Result<Vec<Expr>> {
790 if expressions.is_empty() {
791 return Ok(expressions);
792 }
793
794 let current_types = expressions
795 .iter()
796 .map(|e| e.get_type(schema))
797 .collect::<Result<Vec<_>>>()?;
798
799 let new_types = data_types_with_scalar_udf(¤t_types, func)?;
800
801 expressions
802 .into_iter()
803 .enumerate()
804 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
805 .collect()
806}
807
808fn coerce_arguments_for_signature_with_aggregate_udf(
813 expressions: Vec<Expr>,
814 schema: &DFSchema,
815 func: &AggregateUDF,
816) -> Result<Vec<Expr>> {
817 if expressions.is_empty() {
818 return Ok(expressions);
819 }
820
821 let current_fields = expressions
822 .iter()
823 .map(|e| e.to_field(schema).map(|(_, f)| f))
824 .collect::<Result<Vec<_>>>()?;
825
826 let new_types = fields_with_aggregate_udf(¤t_fields, func)?
827 .into_iter()
828 .map(|f| f.data_type().clone())
829 .collect::<Vec<_>>();
830
831 expressions
832 .into_iter()
833 .enumerate()
834 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
835 .collect()
836}
837
838fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
839 let case_type = case
871 .expr
872 .as_ref()
873 .map(|expr| expr.get_type(schema))
874 .transpose()?;
875 let then_types = case
876 .when_then_expr
877 .iter()
878 .map(|(_when, then)| then.get_type(schema))
879 .collect::<Result<Vec<_>>>()?;
880 let else_type = case
881 .else_expr
882 .as_ref()
883 .map(|expr| expr.get_type(schema))
884 .transpose()?;
885
886 let case_when_coerce_type = case_type
888 .as_ref()
889 .map(|case_type| {
890 let when_types = case
891 .when_then_expr
892 .iter()
893 .map(|(when, _then)| when.get_type(schema))
894 .collect::<Result<Vec<_>>>()?;
895 let coerced_type =
896 get_coerce_type_for_case_expression(&when_types, Some(case_type));
897 coerced_type.ok_or_else(|| {
898 plan_datafusion_err!(
899 "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \
900 to common types in CASE WHEN expression"
901 )
902 })
903 })
904 .transpose()?;
905 let then_else_coerce_type =
906 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
907 || {
908 plan_datafusion_err!(
909 "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \
910 to common types in CASE WHEN expression"
911 )
912 },
913 )?;
914
915 let case_expr = case
917 .expr
918 .zip(case_when_coerce_type.as_ref())
919 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
920 .transpose()?
921 .map(Box::new);
922 let when_then = case
923 .when_then_expr
924 .into_iter()
925 .map(|(when, then)| {
926 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
927 let when = when.cast_to(when_type, schema).map_err(|e| {
928 DataFusionError::Context(
929 format!(
930 "WHEN expressions in CASE couldn't be \
931 converted to common type ({when_type})"
932 ),
933 Box::new(e),
934 )
935 })?;
936 let then = then.cast_to(&then_else_coerce_type, schema)?;
937 Ok((Box::new(when), Box::new(then)))
938 })
939 .collect::<Result<Vec<_>>>()?;
940 let else_expr = case
941 .else_expr
942 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
943 .transpose()?
944 .map(Box::new);
945
946 Ok(Case::new(case_expr, when_then, else_expr))
947}
948
949pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
991 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
992}
993fn coerce_union_schema_with_schema(
994 inputs: &[Arc<LogicalPlan>],
995 base_schema: &DFSchemaRef,
996) -> Result<DFSchema> {
997 let mut union_datatypes = base_schema
998 .fields()
999 .iter()
1000 .map(|f| f.data_type().clone())
1001 .collect::<Vec<_>>();
1002 let mut union_nullabilities = base_schema
1003 .fields()
1004 .iter()
1005 .map(|f| f.is_nullable())
1006 .collect::<Vec<_>>();
1007 let mut union_field_meta = base_schema
1008 .fields()
1009 .iter()
1010 .map(|f| f.metadata().clone())
1011 .collect::<Vec<_>>();
1012
1013 let mut metadata = base_schema.metadata().clone();
1014
1015 for (i, plan) in inputs.iter().enumerate() {
1016 let plan_schema = plan.schema();
1017 metadata.extend(plan_schema.metadata().clone());
1018
1019 if plan_schema.fields().len() != base_schema.fields().len() {
1020 return plan_err!(
1021 "Union schemas have different number of fields: \
1022 query 1 has {} fields whereas query {} has {} fields",
1023 base_schema.fields().len(),
1024 i + 1,
1025 plan_schema.fields().len()
1026 );
1027 }
1028
1029 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1031 union_datatypes.iter_mut(),
1032 union_nullabilities.iter_mut(),
1033 union_field_meta.iter_mut(),
1034 plan_schema.fields().iter()
1035 ) {
1036 let coerced_type =
1037 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1038 || {
1039 plan_datafusion_err!(
1040 "Incompatible inputs for Union: Previous inputs were \
1041 of type {}, but got incompatible type {} on column '{}'",
1042 union_datatype,
1043 plan_field.data_type(),
1044 plan_field.name()
1045 )
1046 },
1047 )?;
1048
1049 *union_datatype = coerced_type;
1050 *union_nullable = *union_nullable || plan_field.is_nullable();
1051 union_field_map.extend(plan_field.metadata().clone());
1052 }
1053 }
1054 let union_qualified_fields = izip!(
1055 base_schema.fields(),
1056 union_datatypes.into_iter(),
1057 union_nullabilities,
1058 union_field_meta.into_iter()
1059 )
1060 .map(|(field, datatype, nullable, metadata)| {
1061 let mut field = Field::new(field.name().clone(), datatype, nullable);
1062 field.set_metadata(metadata);
1063 (None, field.into())
1064 })
1065 .collect::<Vec<_>>();
1066
1067 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1068}
1069
1070fn project_with_column_index(
1072 expr: Vec<Expr>,
1073 input: Arc<LogicalPlan>,
1074 schema: DFSchemaRef,
1075) -> Result<LogicalPlan> {
1076 let alias_expr = expr
1077 .into_iter()
1078 .enumerate()
1079 .map(|(i, e)| match e {
1080 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1081 Ok(e.unalias().alias(schema.field(i).name()))
1082 }
1083 Expr::Column(Column {
1084 relation: _,
1085 ref name,
1086 spans: _,
1087 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1088 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1089 #[expect(deprecated)]
1090 Expr::Wildcard { .. } => {
1091 plan_err!("Wildcard should be expanded before type coercion")
1092 }
1093 _ => Ok(e.alias(schema.field(i).name())),
1094 })
1095 .collect::<Result<Vec<_>>>()?;
1096
1097 Projection::try_new_with_schema(alias_expr, input, schema)
1098 .map(LogicalPlan::Projection)
1099}
1100
1101#[cfg(test)]
1102mod test {
1103 use std::any::Any;
1104 use std::sync::Arc;
1105
1106 use arrow::datatypes::DataType::Utf8;
1107 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1108 use insta::assert_snapshot;
1109
1110 use crate::analyzer::type_coercion::{
1111 coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1112 };
1113 use crate::analyzer::Analyzer;
1114 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1115 use datafusion_common::config::ConfigOptions;
1116 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1117 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1118 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1119 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1120 use datafusion_expr::test::function_stub::avg_udaf;
1121 use datafusion_expr::{
1122 cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1123 BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1124 Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1125 SimpleAggregateUDF, Subquery, Union, Volatility,
1126 };
1127 use datafusion_functions_aggregate::average::AvgAccumulator;
1128 use datafusion_sql::TableReference;
1129
1130 fn empty() -> Arc<LogicalPlan> {
1131 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1132 produce_one_row: false,
1133 schema: Arc::new(DFSchema::empty()),
1134 }))
1135 }
1136
1137 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1138 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1139 produce_one_row: false,
1140 schema: Arc::new(
1141 DFSchema::from_unqualified_fields(
1142 vec![Field::new("a", data_type, true)].into(),
1143 std::collections::HashMap::new(),
1144 )
1145 .unwrap(),
1146 ),
1147 }))
1148 }
1149
1150 macro_rules! assert_analyzed_plan_eq {
1151 (
1152 $plan: expr,
1153 @ $expected: literal $(,)?
1154 ) => {{
1155 let options = ConfigOptions::default();
1156 let rule = Arc::new(TypeCoercion::new());
1157 assert_analyzed_plan_with_config_eq_snapshot!(
1158 options,
1159 rule,
1160 $plan,
1161 @ $expected,
1162 )
1163 }};
1164 }
1165
1166 macro_rules! coerce_on_output_if_viewtype {
1167 (
1168 $is_viewtype: expr,
1169 $plan: expr,
1170 @ $expected: literal $(,)?
1171 ) => {{
1172 let mut options = ConfigOptions::default();
1173 if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1175 let rule = Arc::new(TypeCoercion::new());
1176
1177 assert_analyzed_plan_with_config_eq_snapshot!(
1178 options,
1179 rule,
1180 $plan,
1181 @ $expected,
1182 )
1183 }};
1184 }
1185
1186 fn assert_type_coercion_error(
1187 plan: LogicalPlan,
1188 expected_substr: &str,
1189 ) -> Result<()> {
1190 let options = ConfigOptions::default();
1191 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1192
1193 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1194 Ok(succeeded_plan) => {
1195 panic!(
1196 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1197 );
1198 }
1199 Err(e) => {
1200 let msg = e.to_string();
1201 assert!(
1202 msg.contains(expected_substr),
1203 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1204 );
1205 }
1206 }
1207
1208 Ok(())
1209 }
1210
1211 #[test]
1212 fn simple_case() -> Result<()> {
1213 let expr = col("a").lt(lit(2_u32));
1214 let empty = empty_with_type(DataType::Float64);
1215 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1216
1217 assert_analyzed_plan_eq!(
1218 plan,
1219 @r"
1220 Projection: a < CAST(UInt32(2) AS Float64)
1221 EmptyRelation: rows=0
1222 "
1223 )
1224 }
1225
1226 #[test]
1227 fn test_coerce_union() -> Result<()> {
1228 let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1229 produce_one_row: false,
1230 schema: Arc::new(
1231 DFSchema::try_from_qualified_schema(
1232 TableReference::full("datafusion", "test", "foo"),
1233 &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1234 )
1235 .unwrap(),
1236 ),
1237 }));
1238 let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1239 produce_one_row: false,
1240 schema: Arc::new(
1241 DFSchema::try_from_qualified_schema(
1242 TableReference::full("datafusion", "test", "foo"),
1243 &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1244 )
1245 .unwrap(),
1246 ),
1247 }));
1248 let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1249 left_plan, right_plan,
1250 ])?);
1251 let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1252 .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1253 let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1254 vec![col("a")],
1255 Arc::new(analyzed_union),
1256 )?);
1257
1258 assert_analyzed_plan_eq!(
1259 top_level_plan,
1260 @r"
1261 Projection: a
1262 Union
1263 Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1264 EmptyRelation: rows=0
1265 EmptyRelation: rows=0
1266 "
1267 )
1268 }
1269
/// Output coercion of `Utf8View` columns. The boolean first argument of
/// `coerce_on_output_if_viewtype!` selects whether output coercion is
/// requested (presumably the `expand_views_at_output` option; TODO confirm
/// against the macro).
///
/// When coercion is requested, a projected `Utf8View` column becomes
/// `LargeUtf8`. A boolean comparison output is left uncast: only the literal
/// is cast to `Utf8View`, for the comparison itself.
#[test]
fn coerce_utf8view_output() -> Result<()> {
    let expr = col("a");
    let empty = empty_with_type(DataType::Utf8View);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![expr.clone()],
        Arc::clone(&empty),
    )?);

    // Plan A: no coercion requested.
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
    Projection: a
      EmptyRelation: rows=0
    "
    )?;

    // Plan B: coercion requested, so the Utf8View output becomes LargeUtf8.
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      EmptyRelation: rows=0
    "
    )?;

    let bool_expr = col("a").lt(lit("foo"));
    let bool_plan = LogicalPlan::Projection(Projection::try_new(
        vec![bool_expr],
        Arc::clone(&empty),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        bool_plan.clone(),
        @r#"
    Projection: a < CAST(Utf8("foo") AS Utf8View)
      EmptyRelation: rows=0
    "#
    )?;

    // Coercion requested on a Boolean output: nothing to cast at the output.
    // This case was previously untested; the test re-checked `plan` instead.
    coerce_on_output_if_viewtype!(
        true,
        bool_plan.clone(),
        @r#"
    Projection: a < CAST(Utf8("foo") AS Utf8View)
      EmptyRelation: rows=0
    "#
    )?;

    let sort_expr = expr.sort(true, true);
    let sort_plan = LogicalPlan::Sort(Sort {
        expr: vec![sort_expr],
        input: Arc::new(plan),
        fetch: None,
    });

    coerce_on_output_if_viewtype!(
        false,
        sort_plan.clone(),
        @r"
    Sort: a ASC NULLS FIRST
      Projection: a
        EmptyRelation: rows=0
    "
    )?;

    // With coercion, a casting projection is placed above the Sort.
    coerce_on_output_if_viewtype!(
        true,
        sort_plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sort_plan),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
    Projection: a
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    Ok(())
}
1400
/// Output coercion of `BinaryView` columns. When coercion is requested, a
/// projected `BinaryView` column becomes `LargeBinary`. A boolean comparison
/// output is left uncast: only the literal is cast to `BinaryView`.
#[test]
fn coerce_binaryview_output() -> Result<()> {
    let column_a = col("a");
    let empty = empty_with_type(DataType::BinaryView);
    let projection = LogicalPlan::Projection(Projection::try_new(
        vec![column_a.clone()],
        Arc::clone(&empty),
    )?);

    // Bare projection of the view column, without and with output coercion.
    coerce_on_output_if_viewtype!(
        false,
        projection.clone(),
        @r"
    Projection: a
      EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        projection.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      EmptyRelation: rows=0
    "
    )?;

    // Boolean output: identical with or without output coercion.
    let comparison = col("a").lt(lit(vec![8, 1, 8, 1]));
    let comparison_plan = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        Arc::clone(&empty),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        comparison_plan.clone(),
        @r#"
    Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
      EmptyRelation: rows=0
    "#
    )?;
    coerce_on_output_if_viewtype!(
        true,
        comparison_plan.clone(),
        @r#"
    Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
      EmptyRelation: rows=0
    "#
    )?;

    // Sort over the projection: with coercion, a casting projection is
    // placed above the Sort.
    let sort_plan = LogicalPlan::Sort(Sort {
        expr: vec![column_a.sort(true, true)],
        input: Arc::new(projection),
        fetch: None,
    });
    coerce_on_output_if_viewtype!(
        false,
        sort_plan.clone(),
        @r"
    Sort: a ASC NULLS FIRST
      Projection: a
        EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        sort_plan.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    // A projection above the Sort gets the cast on its own output column.
    let outer_projection = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sort_plan),
    )?);
    coerce_on_output_if_viewtype!(
        false,
        outer_projection.clone(),
        @r"
    Projection: a
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;
    coerce_on_output_if_viewtype!(
        true,
        outer_projection.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation: rows=0
    "
    )?;

    Ok(())
}
1524
/// Coercion is applied to every comparison nested inside an `OR`.
#[test]
fn nested_case() -> Result<()> {
    let comparison = col("a").lt(lit(2_u32));
    let disjunction = comparison.clone().or(comparison);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![disjunction],
        empty_with_type(DataType::Float64),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
      EmptyRelation: rows=0
    "
    )
}
1543
1544 #[derive(Debug, PartialEq, Eq, Hash)]
1545 struct TestScalarUDF {
1546 signature: Signature,
1547 }
1548
1549 impl ScalarUDFImpl for TestScalarUDF {
1550 fn as_any(&self) -> &dyn Any {
1551 self
1552 }
1553
1554 fn name(&self) -> &str {
1555 "TestScalarUDF"
1556 }
1557
1558 fn signature(&self) -> &Signature {
1559 &self.signature
1560 }
1561
1562 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1563 Ok(Utf8)
1564 }
1565
1566 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1567 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1568 }
1569 }
1570
/// An `Int32` argument to a `Float32`-only UDF is cast to `Float32`.
#[test]
fn scalar_udf() -> Result<()> {
    let empty = empty();

    let float32_udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float32_udf.call(vec![lit(123_i32)]);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
      EmptyRelation: rows=0
    "
    )
}
1589
/// A string argument to a `Float32`-only UDF makes the projection
/// constructor fail.
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
    let empty = empty();
    let float32_udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float32_udf.call(vec![lit("Apple")]);

    let outcome = Projection::try_new(vec![call], empty);
    assert!(
        outcome.is_err(),
        "Expected an error due to incorrect function input"
    );

    Ok(())
}
1602
/// Same as `scalar_udf`, but builds the `Expr::ScalarFunction` node directly
/// from an `Int64` literal.
#[test]
fn scalar_function() -> Result<()> {
    let empty = empty();
    let udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    }));
    let call = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(10i64)]));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
      EmptyRelation: rows=0
    "
    )
}
1626
/// An `Int64` argument to a `Float64` user-defined aggregate is cast to
/// `Float64`.
#[test]
fn agg_udaf() -> Result<()> {
    let empty = empty();
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let state_types = Arc::new(vec![DataType::UInt64, DataType::Float64]);
    let my_avg = create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        accumulator,
        state_types,
    );

    let aggregate = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit(10i64)],
        false,
        None,
        vec![],
        None,
    );
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![Expr::AggregateFunction(aggregate)],
        empty,
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: MY_AVG(CAST(Int64(10) AS Float64))
      EmptyRelation: rows=0
    "
    )
}
1656
/// A `Utf8` argument to a `Float64`-only user-defined aggregate fails during
/// planning with a coercion error.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let empty = empty();
    let return_type = DataType::Float64;
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
        "MY_AVG",
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        return_type,
        accumulator,
        vec![
            Field::new("count", DataType::UInt64, true).into(),
            Field::new("avg", DataType::Float64, true).into(),
        ],
    ));
    let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit("10")],
        false,
        None,
        vec![],
        None,
    ));

    // `expect_err` and the assertion message make a failure self-explaining:
    // the actual error text is printed instead of a bare `false`.
    let err = Projection::try_new(vec![udaf], empty)
        .expect_err("expected MY_AVG(Utf8) to fail argument coercion")
        .strip_backtrace();
    let expected = "Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed";
    assert!(err.starts_with(expected), "unexpected error: {err}");
    Ok(())
}
1688
/// The built-in `avg` aggregate accepts an already-`Float64` literal, and an
/// explicitly cast column, without adding further casts.
#[test]
fn agg_function_case() -> Result<()> {
    // Wraps a single argument in a plain (non-distinct, unfiltered) avg call.
    let avg_of = |arg: Expr| {
        Expr::AggregateFunction(expr::AggregateFunction::new_udf(
            avg_udaf(),
            vec![arg],
            false,
            None,
            vec![],
            None,
        ))
    };

    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![avg_of(lit(12f64))], empty())?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: avg(Float64(12))
      EmptyRelation: rows=0
    "
    )?;

    let int32_input = empty_with_type(DataType::Int32);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of(cast(col("a"), DataType::Float64))],
        int32_input,
    )?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: avg(CAST(a AS Float64))
      EmptyRelation: rows=0
    "
    )
}
1729
/// `avg` over a `Utf8` literal fails during planning with a coercion error
/// that lists every numeric type `avg` accepts.
#[test]
fn agg_function_invalid_input_avg() -> Result<()> {
    let empty = empty();
    let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![lit("1")],
        false,
        None,
        vec![],
        None,
    ));

    // Report the actual error text on mismatch instead of a bare `false`.
    let err = Projection::try_new(vec![agg_expr], empty)
        .expect_err("expected avg(Utf8) to fail argument coercion")
        .strip_backtrace();
    let expected = "Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed";
    assert!(err.starts_with(expected), "unexpected error: {err}");
    Ok(())
}
1748
/// `Date32 + IntervalDayTime` is accepted as-is, with no extra casts.
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let interval = lit(ScalarValue::new_interval_dt(123, 456));
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![date + interval],
        empty(),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
      EmptyRelation: rows=0
    "#
    )
}
1765
/// `IN` list coercion. Against an `Int64` column only the narrower literals
/// are cast. Against a `Decimal128(12, 4)` column, both the column and all
/// literals are cast to a wider decimal.
#[test]
fn inlist_case() -> Result<()> {
    let in_list = || col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![in_list()],
        empty_with_type(DataType::Int64),
    )?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
      EmptyRelation: rows=0
    "
    )?;

    let decimal_schema = DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
        std::collections::HashMap::new(),
    )?;
    let decimal_input = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(decimal_schema),
    }));
    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![in_list()], decimal_input)?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
      EmptyRelation: rows=0
    "
    )
}
1796
/// `Utf8 BETWEEN Utf8 AND Date32`: the column and the string bound are both
/// cast to `Date32`.
#[test]
fn between_case() -> Result<()> {
    let low = lit("2002-05-08");
    let high = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let predicate = col("a").between(low, high);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty_with_type(Utf8))?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
      EmptyRelation: rows=0
    "#
    )
}
1816
/// Same as `between_case` with the bounds swapped: the common type is still
/// `Date32`, whichever bound carries it.
#[test]
fn between_infer_cheap_type() -> Result<()> {
    let low = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let high = lit("2002-12-08");
    let predicate = col("a").between(low, high);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty_with_type(Utf8))?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
      EmptyRelation: rows=0
    "#
    )
}
1837
/// Untyped NULL operands of `BETWEEN` take the type of the only typed bound.
#[test]
fn between_null() -> Result<()> {
    let null = || lit(ScalarValue::Null);
    let predicate = null().between(null(), lit(2i64));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
      EmptyRelation: rows=0
    "
    )
}
1852
/// `IS [NOT] TRUE/FALSE` on a boolean column need no coercion. On an `Int64`
/// column the analyzer reports a coercion error.
#[test]
fn is_bool_for_type_coercion() -> Result<()> {
    // Projects `predicate` over a single column `a` of `column_type`.
    let project = |predicate: Expr, column_type: DataType| -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![predicate],
            empty_with_type(column_type),
        )?))
    };

    let is_true_expr = col("a").is_true();
    let plan = project(is_true_expr.clone(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS TRUE
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(is_true_expr, DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
    )?;

    let plan = project(col("a").is_not_true(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT TRUE
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(col("a").is_false(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS FALSE
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(col("a").is_not_false(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT FALSE
      EmptyRelation: rows=0
    "
    )
}
1915
/// LIKE and ILIKE coercion. A `Utf8` pattern on a `Utf8` column is unchanged,
/// a NULL pattern is cast to `Utf8`, and an `Int64` column has no common
/// type with a `Utf8` pattern, so the analyzer reports an error.
#[test]
fn like_for_type_coercion() -> Result<()> {
    // Projects `a [I]LIKE <pattern>` (never negated, no escape char) over a
    // single column `a` of `column_type`.
    let like_plan = |case_insensitive: bool,
                     pattern: ScalarValue,
                     column_type: DataType|
     -> Result<LogicalPlan> {
        let like_expr = Expr::Like(Like::new(
            false,
            Box::new(col("a")),
            Box::new(lit(pattern)),
            None,
            case_insensitive,
        ));
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![like_expr],
            empty_with_type(column_type),
        )?))
    };

    // LIKE
    let plan = like_plan(false, ScalarValue::new_utf8("abc"), Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a LIKE Utf8("abc")
      EmptyRelation: rows=0
    "#
    )?;

    let plan = like_plan(false, ScalarValue::Null, Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a LIKE CAST(NULL AS Utf8)
      EmptyRelation: rows=0
    "
    )?;

    let plan = like_plan(false, ScalarValue::new_utf8("abc"), DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
    )?;

    // ILIKE
    let plan = like_plan(true, ScalarValue::new_utf8("abc"), Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a ILIKE Utf8("abc")
      EmptyRelation: rows=0
    "#
    )?;

    let plan = like_plan(true, ScalarValue::Null, Utf8)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a ILIKE CAST(NULL AS Utf8)
      EmptyRelation: rows=0
    "
    )?;

    let plan = like_plan(true, ScalarValue::new_utf8("abc"), DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
    )?;

    Ok(())
}
1998
/// `IS [NOT] UNKNOWN` on a boolean column need no coercion. On a `Utf8`
/// column the analyzer reports a coercion error.
#[test]
fn unknown_for_type_coercion() -> Result<()> {
    // Projects `predicate` over a single column `a` of `column_type`.
    let project = |predicate: Expr, column_type: DataType| -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![predicate],
            empty_with_type(column_type),
        )?))
    };

    let is_unknown_expr = col("a").is_unknown();
    let plan = project(is_unknown_expr.clone(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS UNKNOWN
      EmptyRelation: rows=0
    "
    )?;

    let plan = project(is_unknown_expr, Utf8)?;
    assert_type_coercion_error(
        plan,
        "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
    )?;

    let plan = project(col("a").is_not_unknown(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT UNKNOWN
      EmptyRelation: rows=0
    "
    )
}
2035
/// With a variadic `Utf8` signature, every non-`Utf8` argument (booleans, an
/// `Int32`) is cast to `Utf8`.
#[test]
fn concat_for_type_coercion() -> Result<()> {
    let empty = empty_with_type(Utf8);
    let variadic_utf8 = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
    });
    let call = variadic_utf8.call(vec![
        col("a"),
        lit("b"),
        lit(true),
        lit(false),
        lit(13),
    ]);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
      EmptyRelation: rows=0
    "#
    )
}
2056
/// Running `TypeCoercionRewriter` directly on `IS TRUE (Int32 <op> Int64)`
/// widens the `Int32` literal to `Int64` for `>`, `=`, and `<`.
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Int64, true)].into(),
        std::collections::HashMap::new(),
    )?);

    let comparisons: [fn(Expr, Expr) -> Expr; 3] = [
        |left, right| left.gt(right),
        |left, right| left.eq(right),
        |left, right| left.lt(right),
    ];
    for compare in comparisons {
        // A fresh rewriter per case, as each rewrite is independent.
        let mut rewriter = TypeCoercionRewriter { schema: &schema };
        let input = is_true(compare(lit(12i32), lit(13i64)));
        let expected = is_true(compare(cast(lit(12i32), DataType::Int64), lit(13i64)));
        let rewritten = input.rewrite(&mut rewriter).data()?;
        assert_eq!(expected, rewritten);
    }

    Ok(())
}
2094
/// Comparing a nanosecond timestamp with a `Date32` casts the date side to
/// the timestamp type.
#[test]
fn binary_op_date32_eq_ts() -> Result<()> {
    let timestamp_nanos = DataType::Timestamp(TimeUnit::Nanosecond, None);
    let timestamp = cast(lit("1998-03-18"), timestamp_nanos);
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![timestamp.eq(date)],
        empty(),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None))
      EmptyRelation: rows=0
    "#
    )
}
2113
2114 fn cast_if_not_same_type(
2115 expr: Box<Expr>,
2116 data_type: &DataType,
2117 schema: &DFSchemaRef,
2118 ) -> Box<Expr> {
2119 if &expr.get_type(schema).unwrap() != data_type {
2120 Box::new(cast(*expr, data_type.clone()))
2121 } else {
2122 expr
2123 }
2124 }
2125
2126 fn cast_helper(
2127 case: Case,
2128 case_when_type: &DataType,
2129 then_else_type: &DataType,
2130 schema: &DFSchemaRef,
2131 ) -> Case {
2132 let expr = case
2133 .expr
2134 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2135 let when_then_expr = case
2136 .when_then_expr
2137 .into_iter()
2138 .map(|(when, then)| {
2139 (
2140 cast_if_not_same_type(when, case_when_type, schema),
2141 cast_if_not_same_type(then, then_else_type, schema),
2142 )
2143 })
2144 .collect::<Vec<_>>();
2145 let else_expr = case
2146 .else_expr
2147 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2148
2149 Case {
2150 expr,
2151 when_then_expr,
2152 else_expr,
2153 }
2154 }
2155
/// `coerce_case_expression` finds a common WHEN type and a common THEN/ELSE
/// type, or reports which group could not be unified.
#[test]
fn test_case_expression_coercion() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new("integer", DataType::Int32, true),
            Field::new("float", DataType::Float32, true),
            Field::new(
                "timestamp",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                true,
            ),
            Field::new("date", DataType::Date32, true),
            Field::new(
                "interval",
                DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
                true,
            ),
            Field::new("binary", DataType::Binary, true),
            Field::new("string", Utf8, true),
            Field::new("decimal", DataType::Decimal128(10, 10), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // A boxed WHEN/THEN pair built from two column names.
    let branch = |when: &str, then: &str| (Box::new(col(when)), Box::new(col(then)));

    // Coerces `case` and compares it with the casts `cast_helper` predicts.
    let assert_coerces_to =
        |case: Case, case_when_type: DataType, then_else_type: DataType| -> Result<()> {
            let expected =
                cast_helper(case.clone(), &case_when_type, &then_else_type, &schema);
            assert_eq!(expected, coerce_case_expression(case, &schema)?);
            Ok(())
        };

    // Searched CASE (no base expression): WHEN arms -> Boolean, THEN -> Utf8.
    assert_coerces_to(
        Case {
            expr: None,
            when_then_expr: vec![
                branch("boolean", "integer"),
                branch("integer", "float"),
                branch("string", "string"),
            ],
            else_expr: None,
        },
        DataType::Boolean,
        Utf8,
    )?;

    // Simple CASE on a string: base/WHEN -> Utf8, THEN/ELSE -> Utf8.
    assert_coerces_to(
        Case {
            expr: Some(Box::new(col("string"))),
            when_then_expr: vec![
                branch("float", "integer"),
                branch("integer", "float"),
                branch("string", "string"),
            ],
            else_expr: Some(Box::new(col("string"))),
        },
        Utf8,
        Utf8,
    )?;

    // No common type between the interval base and its WHEN arms.
    let err = coerce_case_expression(
        Case {
            expr: Some(Box::new(col("interval"))),
            when_then_expr: vec![
                branch("float", "integer"),
                branch("binary", "float"),
                branch("string", "string"),
            ],
            else_expr: Some(Box::new(col("string"))),
        },
        &schema,
    )
    .unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression"
    );

    // No common type among the THEN arms and the ELSE arm.
    let err = coerce_case_expression(
        Case {
            expr: Some(Box::new(col("string"))),
            when_then_expr: vec![
                branch("float", "date"),
                branch("string", "float"),
                branch("string", "binary"),
            ],
            else_expr: Some(Box::new(col("timestamp"))),
        },
        &schema,
    )
    .unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression"
    );

    Ok(())
}
2254
/// Builds a `CASE` from an optional base column name (`$expr`) and a list of
/// WHEN/THEN pairs, with no ELSE. It then checks that `coerce_case_expression`
/// produces what `cast_helper` predicts for the given WHEN type and
/// THEN/ELSE type. Coercion errors are propagated with `?`, so the macro must
/// be used inside a function returning `Result`.
macro_rules! test_case_expression {
    ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
        let input_case = Case {
            expr: $expr.map(|name| Box::new(col(name))),
            when_then_expr: $when_then,
            else_expr: None,
        };
        let expected_case = cast_helper(
            input_case.clone(),
            &$case_when_type,
            &$then_else_type,
            &$schema,
        );
        assert_eq!(expected_case, coerce_case_expression(input_case, &$schema)?);
    };
}
2270
/// Checks how the base expression and WHEN arms of a CASE are coerced when
/// they mix `List`, `LargeList`, and `FixedSizeList` of `Int64`. Any pairing
/// with `LargeList` yields `LargeList`; `List` with `FixedSizeList` yields
/// `List`. The THEN arms are `Utf8` literals throughout.
///
/// Renamed from the misspelled `tes_case_when_list`.
#[test]
fn test_case_when_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new(
                "large_list",
                DataType::LargeList(Arc::clone(&inner_field)),
                true,
            ),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(Arc::clone(&inner_field), 3),
                true,
            ),
            Field::new("list", DataType::List(inner_field), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // Expected common WHEN types, built once instead of per assertion.
    let large_list_type =
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true)));
    let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));

    test_case_expression!(
        Some("list"),
        vec![(Box::new(col("large_list")), Box::new(lit("1")))],
        large_list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("large_list"),
        vec![(Box::new(col("list")), Box::new(lit("1")))],
        large_list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("list"),
        vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
        list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("fixed_list"),
        vec![(Box::new(col("list")), Box::new(lit("1")))],
        list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("fixed_list"),
        vec![(Box::new(col("large_list")), Box::new(lit("1")))],
        large_list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("large_list"),
        vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
        large_list_type,
        Utf8,
        schema
    );
    Ok(())
}
2341
/// Checks how the THEN arms of a searched CASE (boolean WHEN arms) are
/// coerced when they mix list kinds of `Int64`. Any pairing with `LargeList`
/// yields `LargeList`; `List` with `FixedSizeList` yields `List`. Both
/// orderings of each pair are checked.
#[test]
fn test_then_else_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new(
                "large_list",
                DataType::LargeList(Arc::clone(&inner_field)),
                true,
            ),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(Arc::clone(&inner_field), 3),
                true,
            ),
            Field::new("list", DataType::List(inner_field), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // `WHEN boolean THEN <column>` arm.
    let when_bool = |then: &str| (Box::new(col("boolean")), Box::new(col(then)));
    let large_list_type =
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true)));
    let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));

    // LargeList + List -> LargeList
    test_case_expression!(
        None::<String>,
        vec![when_bool("large_list"), when_bool("list")],
        DataType::Boolean,
        large_list_type,
        schema
    );
    test_case_expression!(
        None::<String>,
        vec![when_bool("list"), when_bool("large_list")],
        DataType::Boolean,
        large_list_type,
        schema
    );

    // FixedSizeList + List -> List
    test_case_expression!(
        None::<String>,
        vec![when_bool("fixed_list"), when_bool("list")],
        DataType::Boolean,
        list_type,
        schema
    );
    test_case_expression!(
        None::<String>,
        vec![when_bool("list"), when_bool("fixed_list")],
        DataType::Boolean,
        list_type,
        schema
    );

    // FixedSizeList + LargeList -> LargeList
    test_case_expression!(
        None::<String>,
        vec![when_bool("fixed_list"), when_bool("large_list")],
        DataType::Boolean,
        large_list_type,
        schema
    );
    test_case_expression!(
        None::<String>,
        vec![when_bool("large_list"), when_bool("fixed_list")],
        DataType::Boolean,
        large_list_type,
        schema
    );
    Ok(())
}
2434
/// Comparing two map types that differ only in the name of the entries field
/// (`entries` vs `key_value`) casts the right-hand side back to the column's
/// map type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    let mut builder = SchemaBuilder::new();
    builder.push(Field::new("key", Utf8, false));
    builder.push(Field::new("value", DataType::Float64, true));
    let struct_fields = builder.finish().fields;

    // Non-sorted map whose entries struct field is called `entry_name`.
    let map_named = |entry_name: &str| {
        let entries = Field::new(entry_name, DataType::Struct(struct_fields.clone()), false);
        DataType::Map(Arc::new(entries), false)
    };
    let map_type_entries = map_named("entries");
    let map_type_custom = map_named("key_value");

    let comparison = col("a").eq(cast(col("a"), map_type_custom));
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        empty_with_type(map_type_entries),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))
      EmptyRelation: rows=0
    "#
    )
}
2461
#[test]
fn interval_plus_timestamp() -> Result<()> {
    // Left operand: an IntervalYearMonth literal holding 12.
    let interval = lit(ScalarValue::IntervalYearMonth(Some(12)));
    // Right operand: a string literal cast to a nanosecond timestamp.
    let timestamp = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(interval),
        Operator::Plus,
        Box::new(timestamp),
    ));

    // The expected plan shows no extra casts added by the analyzer.
    let projection = Projection::try_new(vec![sum], empty())?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None))
      EmptyRelation: rows=0
    "#
    )
}
2484
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both operands are the same string literal cast to a nanosecond timestamp.
    let ts_operand = || {
        Box::new(cast(
            lit("1998-03-18"),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        ))
    };
    let difference =
        Expr::BinaryExpr(BinaryExpr::new(ts_operand(), Operator::Minus, ts_operand()));

    // The expected plan keeps the subtraction exactly as written.
    let projection = Projection::try_new(vec![difference], empty())?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None))
      EmptyRelation: rows=0
    "#
    )
}
2509
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    let inner_relation = empty_with_type(DataType::Int32);
    let outer_relation = empty_with_type(DataType::Int64);

    // Outer column is Int64, subquery yields Int32: per the expected plan the
    // subquery side is wrapped in a projection casting it to Int64.
    let subquery = Subquery {
        subquery: inner_relation,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let negated = false;
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, negated));

    let filter = Filter::try_new(predicate, outer_relation)?;
    let plan = LogicalPlan::Filter(filter);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: a IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Int64)
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2538
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    let outer_relation = empty_with_type(DataType::Int32);
    let inner_relation = empty_with_type(DataType::Int64);

    // Outer column is Int32, subquery yields Int64: per the expected plan only
    // the outer expression gets a cast to Int64; the subquery is untouched.
    let subquery = Subquery {
        subquery: inner_relation,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let negated = false;
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, negated));

    let filter = Filter::try_new(predicate, outer_relation)?;
    let plan = LogicalPlan::Filter(filter);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Int64) IN (<subquery>)
      Subquery:
        EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2566
#[test]
fn in_subquery_cast_all() -> Result<()> {
    let inner_relation = empty_with_type(DataType::Decimal128(10, 5));
    let outer_relation = empty_with_type(DataType::Decimal128(8, 8));

    // Neither decimal type covers the other, so per the expected plan both
    // sides are cast to the common type Decimal128(13, 8).
    let subquery = Subquery {
        subquery: inner_relation,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let negated = false;
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, negated));

    let filter = Filter::try_new(predicate, outer_relation)?;
    let plan = LogicalPlan::Filter(filter);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Decimal128(13, 8))
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2595}