1use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::izip;
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32 exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
33 DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34};
35use datafusion_expr::expr::{
36 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
37 InSubquery, Like, ScalarFunction, Sort, WindowFunction,
38};
39use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
40use datafusion_expr::expr_schema::cast_subquery;
41use datafusion_expr::logical_plan::Subquery;
42use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
43use datafusion_expr::type_coercion::functions::{
44 data_types_with_scalar_udf, fields_with_aggregate_udf,
45};
46use datafusion_expr::type_coercion::other::{
47 get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52 is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
53 AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan,
54 Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound,
55 WindowFrameUnits,
56};
57
/// Analyzer rule that makes the expressions of a [`LogicalPlan`] type
/// consistent by inserting casts where operand types differ.
///
/// The rule itself is stateless; all work happens in its [`AnalyzerRule`]
/// implementation.
#[derive(Debug, Default)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self::default()
    }
}
68
69fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
71 if !config.optimizer.expand_views_at_output {
72 return Ok(plan);
73 }
74
75 let outer_refs = plan.expressions();
76 if outer_refs.is_empty() {
77 return Ok(plan);
78 }
79
80 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
81 coerce_plan_expr_for_schema(plan, &dfschema?)
82 } else {
83 Ok(plan)
84 }
85}
86
87impl AnalyzerRule for TypeCoercion {
88 fn name(&self) -> &str {
89 "type_coercion"
90 }
91
92 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
93 let empty_schema = DFSchema::empty();
94
95 let transformed_plan = plan
97 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
98 .data;
99
100 coerce_output(transformed_plan, config)
102 }
103}
104
105fn analyze_internal(
109 external_schema: &DFSchema,
110 plan: LogicalPlan,
111) -> Result<Transformed<LogicalPlan>> {
112 let mut schema = merge_schema(&plan.inputs());
115
116 if let LogicalPlan::TableScan(ts) = &plan {
117 let source_schema = DFSchema::try_from_qualified_schema(
118 ts.table_name.clone(),
119 &ts.source.schema(),
120 )?;
121 schema.merge(&source_schema);
122 }
123
124 schema.merge(external_schema);
128
129 let plan = if let LogicalPlan::Filter(mut filter) = plan {
131 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
132 LogicalPlan::Filter(filter)
133 } else {
134 plan
135 };
136
137 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
138
139 let name_preserver = NamePreserver::new(&plan);
140 plan.map_expressions(|expr| {
142 let original_name = name_preserver.save(&expr);
143 expr.rewrite(&mut expr_rewrite)
144 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
145 })?
146 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
148 .map_data(|plan| plan.recompute_schema())
150}
151
152pub struct TypeCoercionRewriter<'a> {
154 pub(crate) schema: &'a DFSchema,
155}
156
157impl<'a> TypeCoercionRewriter<'a> {
158 pub fn new(schema: &'a DFSchema) -> Self {
161 Self { schema }
162 }
163
164 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
169 match plan {
170 LogicalPlan::Join(join) => self.coerce_join(join),
171 LogicalPlan::Union(union) => Self::coerce_union(union),
172 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
173 _ => Ok(plan),
174 }
175 }
176
177 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
186 join.on = join
187 .on
188 .into_iter()
189 .map(|(lhs, rhs)| {
190 let left_schema = join.left.schema();
193 let right_schema = join.right.schema();
194 let (lhs, rhs) = self.coerce_binary_op(
195 lhs,
196 left_schema,
197 Operator::Eq,
198 rhs,
199 right_schema,
200 )?;
201 Ok((lhs, rhs))
202 })
203 .collect::<Result<Vec<_>>>()?;
204
205 join.filter = join
207 .filter
208 .map(|expr| self.coerce_join_filter(expr))
209 .transpose()?;
210
211 Ok(LogicalPlan::Join(join))
212 }
213
214 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
217 let union_schema = Arc::new(coerce_union_schema_with_schema(
218 &union_plan.inputs,
219 &union_plan.schema,
220 )?);
221 let new_inputs = union_plan
222 .inputs
223 .into_iter()
224 .map(|p| {
225 let plan =
226 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
227 match plan {
228 LogicalPlan::Projection(Projection { expr, input, .. }) => {
229 Ok(Arc::new(project_with_column_index(
230 expr,
231 input,
232 Arc::clone(&union_schema),
233 )?))
234 }
235 other_plan => Ok(Arc::new(other_plan)),
236 }
237 })
238 .collect::<Result<Vec<_>>>()?;
239 Ok(LogicalPlan::Union(Union {
240 inputs: new_inputs,
241 schema: union_schema,
242 }))
243 }
244
245 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
247 fn coerce_limit_expr(
248 expr: Expr,
249 schema: &DFSchema,
250 expr_name: &str,
251 ) -> Result<Expr> {
252 let dt = expr.get_type(schema)?;
253 if dt.is_integer() || dt.is_null() {
254 expr.cast_to(&DataType::Int64, schema)
255 } else {
256 plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}")
257 }
258 }
259
260 let empty_schema = DFSchema::empty();
261 let new_fetch = limit
262 .fetch
263 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
264 .transpose()?;
265 let new_skip = limit
266 .skip
267 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
268 .transpose()?;
269 Ok(LogicalPlan::Limit(Limit {
270 input: limit.input,
271 fetch: new_fetch.map(Box::new),
272 skip: new_skip.map(Box::new),
273 }))
274 }
275
276 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
277 let expr_type = expr.get_type(self.schema)?;
278 match expr_type {
279 DataType::Boolean => Ok(expr),
280 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
281 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
282 }
283 }
284
285 fn coerce_binary_op(
286 &self,
287 left: Expr,
288 left_schema: &DFSchema,
289 op: Operator,
290 right: Expr,
291 right_schema: &DFSchema,
292 ) -> Result<(Expr, Expr)> {
293 let (left_type, right_type) = BinaryTypeCoercer::new(
294 &left.get_type(left_schema)?,
295 &op,
296 &right.get_type(right_schema)?,
297 )
298 .get_input_types()?;
299
300 Ok((
301 left.cast_to(&left_type, left_schema)?,
302 right.cast_to(&right_type, right_schema)?,
303 ))
304 }
305}
306
307impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
308 type Node = Expr;
309
310 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
311 match expr {
312 Expr::Unnest(_) => not_impl_err!(
313 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
314 ),
315 Expr::ScalarSubquery(Subquery {
316 subquery,
317 outer_ref_columns,
318 spans,
319 }) => {
320 let new_plan =
321 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
322 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
323 subquery: Arc::new(new_plan),
324 outer_ref_columns,
325 spans,
326 })))
327 }
328 Expr::Exists(Exists { subquery, negated }) => {
329 let new_plan = analyze_internal(
330 self.schema,
331 Arc::unwrap_or_clone(subquery.subquery),
332 )?
333 .data;
334 Ok(Transformed::yes(Expr::Exists(Exists {
335 subquery: Subquery {
336 subquery: Arc::new(new_plan),
337 outer_ref_columns: subquery.outer_ref_columns,
338 spans: subquery.spans,
339 },
340 negated,
341 })))
342 }
343 Expr::InSubquery(InSubquery {
344 expr,
345 subquery,
346 negated,
347 }) => {
348 let new_plan = analyze_internal(
349 self.schema,
350 Arc::unwrap_or_clone(subquery.subquery),
351 )?
352 .data;
353 let expr_type = expr.get_type(self.schema)?;
354 let subquery_type = new_plan.schema().field(0).data_type();
355 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
356 "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
357 ),
358 )?;
359 let new_subquery = Subquery {
360 subquery: Arc::new(new_plan),
361 outer_ref_columns: subquery.outer_ref_columns,
362 spans: subquery.spans,
363 };
364 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
365 Box::new(expr.cast_to(&common_type, self.schema)?),
366 cast_subquery(new_subquery, &common_type)?,
367 negated,
368 ))))
369 }
370 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
371 *expr,
372 self.schema,
373 )?))),
374 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
375 get_casted_expr_for_bool_op(*expr, self.schema)?,
376 ))),
377 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
378 get_casted_expr_for_bool_op(*expr, self.schema)?,
379 ))),
380 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
381 get_casted_expr_for_bool_op(*expr, self.schema)?,
382 ))),
383 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
384 get_casted_expr_for_bool_op(*expr, self.schema)?,
385 ))),
386 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
387 get_casted_expr_for_bool_op(*expr, self.schema)?,
388 ))),
389 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
390 get_casted_expr_for_bool_op(*expr, self.schema)?,
391 ))),
392 Expr::Like(Like {
393 negated,
394 expr,
395 pattern,
396 escape_char,
397 case_insensitive,
398 }) => {
399 let left_type = expr.get_type(self.schema)?;
400 let right_type = pattern.get_type(self.schema)?;
401 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
402 let op_name = if case_insensitive {
403 "ILIKE"
404 } else {
405 "LIKE"
406 };
407 plan_datafusion_err!(
408 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
409 )
410 })?;
411 let expr = match left_type {
412 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
413 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
414 };
415 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
416 Ok(Transformed::yes(Expr::Like(Like::new(
417 negated,
418 expr,
419 pattern,
420 escape_char,
421 case_insensitive,
422 ))))
423 }
424 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
425 let (left, right) =
426 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
427 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
428 Box::new(left),
429 op,
430 Box::new(right),
431 ))))
432 }
433 Expr::Between(Between {
434 expr,
435 negated,
436 low,
437 high,
438 }) => {
439 let expr_type = expr.get_type(self.schema)?;
440 let low_type = low.get_type(self.schema)?;
441 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
442 .ok_or_else(|| {
443 DataFusionError::Internal(format!(
444 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
445 ))
446 })?;
447 let high_type = high.get_type(self.schema)?;
448 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
449 .ok_or_else(|| {
450 DataFusionError::Internal(format!(
451 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
452 ))
453 })?;
454 let coercion_type =
455 comparison_coercion(&low_coerced_type, &high_coerced_type)
456 .ok_or_else(|| {
457 DataFusionError::Internal(format!(
458 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
459 ))
460 })?;
461 Ok(Transformed::yes(Expr::Between(Between::new(
462 Box::new(expr.cast_to(&coercion_type, self.schema)?),
463 negated,
464 Box::new(low.cast_to(&coercion_type, self.schema)?),
465 Box::new(high.cast_to(&coercion_type, self.schema)?),
466 ))))
467 }
468 Expr::InList(InList {
469 expr,
470 list,
471 negated,
472 }) => {
473 let expr_data_type = expr.get_type(self.schema)?;
474 let list_data_types = list
475 .iter()
476 .map(|list_expr| list_expr.get_type(self.schema))
477 .collect::<Result<Vec<_>>>()?;
478 let result_type =
479 get_coerce_type_for_list(&expr_data_type, &list_data_types);
480 match result_type {
481 None => plan_err!(
482 "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
483 ),
484 Some(coerced_type) => {
485 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
487 let cast_list_expr = list
488 .into_iter()
489 .map(|list_expr| {
490 list_expr.cast_to(&coerced_type, self.schema)
491 })
492 .collect::<Result<Vec<_>>>()?;
493 Ok(Transformed::yes(Expr::InList(InList ::new(
494 Box::new(cast_expr),
495 cast_list_expr,
496 negated,
497 ))))
498 }
499 }
500 }
501 Expr::Case(case) => {
502 let case = coerce_case_expression(case, self.schema)?;
503 Ok(Transformed::yes(Expr::Case(case)))
504 }
505 Expr::ScalarFunction(ScalarFunction { func, args }) => {
506 let new_expr = coerce_arguments_for_signature_with_scalar_udf(
507 args,
508 self.schema,
509 &func,
510 )?;
511 Ok(Transformed::yes(Expr::ScalarFunction(
512 ScalarFunction::new_udf(func, new_expr),
513 )))
514 }
515 Expr::AggregateFunction(expr::AggregateFunction {
516 func,
517 params:
518 AggregateFunctionParams {
519 args,
520 distinct,
521 filter,
522 order_by,
523 null_treatment,
524 },
525 }) => {
526 let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
527 args,
528 self.schema,
529 &func,
530 )?;
531 Ok(Transformed::yes(Expr::AggregateFunction(
532 expr::AggregateFunction::new_udf(
533 func,
534 new_expr,
535 distinct,
536 filter,
537 order_by,
538 null_treatment,
539 ),
540 )))
541 }
542 Expr::WindowFunction(window_fun) => {
543 let WindowFunction {
544 fun,
545 params:
546 expr::WindowFunctionParams {
547 args,
548 partition_by,
549 order_by,
550 window_frame,
551 null_treatment,
552 },
553 } = *window_fun;
554 let window_frame =
555 coerce_window_frame(window_frame, self.schema, &order_by)?;
556
557 let args = match &fun {
558 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
559 coerce_arguments_for_signature_with_aggregate_udf(
560 args,
561 self.schema,
562 udf,
563 )?
564 }
565 _ => args,
566 };
567
568 Ok(Transformed::yes(
569 Expr::from(WindowFunction::new(fun, args))
570 .partition_by(partition_by)
571 .order_by(order_by)
572 .window_frame(window_frame)
573 .null_treatment(null_treatment)
574 .build()?,
575 ))
576 }
577 #[expect(deprecated)]
579 Expr::Alias(_)
580 | Expr::Column(_)
581 | Expr::ScalarVariable(_, _)
582 | Expr::Literal(_, _)
583 | Expr::SimilarTo(_)
584 | Expr::IsNotNull(_)
585 | Expr::IsNull(_)
586 | Expr::Negative(_)
587 | Expr::Cast(_)
588 | Expr::TryCast(_)
589 | Expr::Wildcard { .. }
590 | Expr::GroupingSet(_)
591 | Expr::Placeholder(_)
592 | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
593 }
594 }
595}
596
597fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
599 let metadata = dfschema.as_arrow().metadata.clone();
600 let mut transformed = false;
601
602 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
603 dfschema
604 .iter()
605 .map(|(qualifier, field)| match field.data_type() {
606 DataType::Utf8View => {
607 transformed = true;
608 (
609 qualifier.cloned() as Option<TableReference>,
610 Arc::new(Field::new(
611 field.name(),
612 DataType::LargeUtf8,
613 field.is_nullable(),
614 )),
615 )
616 }
617 DataType::BinaryView => {
618 transformed = true;
619 (
620 qualifier.cloned() as Option<TableReference>,
621 Arc::new(Field::new(
622 field.name(),
623 DataType::LargeBinary,
624 field.is_nullable(),
625 )),
626 )
627 }
628 _ => (
629 qualifier.cloned() as Option<TableReference>,
630 Arc::clone(field),
631 ),
632 })
633 .unzip();
634
635 if !transformed {
636 return None;
637 }
638
639 let schema = Schema::new_with_metadata(transformed_fields, metadata);
640 Some(DFSchema::from_field_specific_qualified_schema(
641 qualifiers,
642 &Arc::new(schema),
643 ))
644}
645
646fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
649 match value {
650 ScalarValue::Utf8(Some(val)) => {
652 ScalarValue::try_from_string(val.clone(), target_type)
653 }
654 s => {
655 if s.is_null() {
656 ScalarValue::try_from(target_type)
658 } else {
659 Ok(s.clone())
663 }
664 }
665 }
666}
667
668fn coerce_scalar_range_aware(
675 target_type: &DataType,
676 value: &ScalarValue,
677) -> Result<ScalarValue> {
678 coerce_scalar(target_type, value).or_else(|err| {
679 if let Some(largest_type) = get_widest_type_in_family(target_type) {
681 coerce_scalar(largest_type, value).map_or_else(
682 |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
683 |_| ScalarValue::try_from(target_type),
684 )
685 } else {
686 Err(err)
687 }
688 })
689}
690
691fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
695 match given_type {
696 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
697 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
698 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
699 _ => None,
700 }
701}
702
703fn coerce_frame_bound(
705 target_type: &DataType,
706 bound: WindowFrameBound,
707) -> Result<WindowFrameBound> {
708 match bound {
709 WindowFrameBound::Preceding(v) => {
710 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
711 }
712 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
713 WindowFrameBound::Following(v) => {
714 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
715 }
716 }
717}
718
719fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
720 if col_type.is_numeric()
721 || is_utf8_or_utf8view_or_large_utf8(col_type)
722 || matches!(col_type, DataType::List(_))
723 || matches!(col_type, DataType::LargeList(_))
724 || matches!(col_type, DataType::FixedSizeList(_, _))
725 || matches!(col_type, DataType::Null)
726 || matches!(col_type, DataType::Boolean)
727 {
728 Ok(col_type.clone())
729 } else if is_datetime(col_type) {
730 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
731 } else if let DataType::Dictionary(_, value_type) = col_type {
732 extract_window_frame_target_type(value_type)
733 } else {
734 return internal_err!("Cannot run range queries on datatype: {col_type:?}");
735 }
736}
737
738fn coerce_window_frame(
741 window_frame: WindowFrame,
742 schema: &DFSchema,
743 expressions: &[Sort],
744) -> Result<WindowFrame> {
745 let mut window_frame = window_frame;
746 let target_type = match window_frame.units {
747 WindowFrameUnits::Range => {
748 let current_types = expressions
749 .first()
750 .map(|s| s.expr.get_type(schema))
751 .transpose()?;
752 if let Some(col_type) = current_types {
753 extract_window_frame_target_type(&col_type)?
754 } else {
755 return internal_err!("ORDER BY column cannot be empty");
756 }
757 }
758 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
759 };
760 window_frame.start_bound =
761 coerce_frame_bound(&target_type, window_frame.start_bound)?;
762 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
763 Ok(window_frame)
764}
765
766fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
769 let left_type = expr.get_type(schema)?;
770 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
771 .get_input_types()?;
772 expr.cast_to(&DataType::Boolean, schema)
773}
774
775fn coerce_arguments_for_signature_with_scalar_udf(
780 expressions: Vec<Expr>,
781 schema: &DFSchema,
782 func: &ScalarUDF,
783) -> Result<Vec<Expr>> {
784 if expressions.is_empty() {
785 return Ok(expressions);
786 }
787
788 let current_types = expressions
789 .iter()
790 .map(|e| e.get_type(schema))
791 .collect::<Result<Vec<_>>>()?;
792
793 let new_types = data_types_with_scalar_udf(¤t_types, func)?;
794
795 expressions
796 .into_iter()
797 .enumerate()
798 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
799 .collect()
800}
801
802fn coerce_arguments_for_signature_with_aggregate_udf(
807 expressions: Vec<Expr>,
808 schema: &DFSchema,
809 func: &AggregateUDF,
810) -> Result<Vec<Expr>> {
811 if expressions.is_empty() {
812 return Ok(expressions);
813 }
814
815 let current_fields = expressions
816 .iter()
817 .map(|e| e.to_field(schema).map(|(_, f)| f))
818 .collect::<Result<Vec<_>>>()?;
819
820 let new_types = fields_with_aggregate_udf(¤t_fields, func)?
821 .into_iter()
822 .map(|f| f.data_type().clone())
823 .collect::<Vec<_>>();
824
825 expressions
826 .into_iter()
827 .enumerate()
828 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
829 .collect()
830}
831
832fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
833 let case_type = case
865 .expr
866 .as_ref()
867 .map(|expr| expr.get_type(schema))
868 .transpose()?;
869 let then_types = case
870 .when_then_expr
871 .iter()
872 .map(|(_when, then)| then.get_type(schema))
873 .collect::<Result<Vec<_>>>()?;
874 let else_type = case
875 .else_expr
876 .as_ref()
877 .map(|expr| expr.get_type(schema))
878 .transpose()?;
879
880 let case_when_coerce_type = case_type
882 .as_ref()
883 .map(|case_type| {
884 let when_types = case
885 .when_then_expr
886 .iter()
887 .map(|(when, _then)| when.get_type(schema))
888 .collect::<Result<Vec<_>>>()?;
889 let coerced_type =
890 get_coerce_type_for_case_expression(&when_types, Some(case_type));
891 coerced_type.ok_or_else(|| {
892 plan_datafusion_err!(
893 "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \
894 to common types in CASE WHEN expression"
895 )
896 })
897 })
898 .transpose()?;
899 let then_else_coerce_type =
900 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
901 || {
902 plan_datafusion_err!(
903 "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \
904 to common types in CASE WHEN expression"
905 )
906 },
907 )?;
908
909 let case_expr = case
911 .expr
912 .zip(case_when_coerce_type.as_ref())
913 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
914 .transpose()?
915 .map(Box::new);
916 let when_then = case
917 .when_then_expr
918 .into_iter()
919 .map(|(when, then)| {
920 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
921 let when = when.cast_to(when_type, schema).map_err(|e| {
922 DataFusionError::Context(
923 format!(
924 "WHEN expressions in CASE couldn't be \
925 converted to common type ({when_type})"
926 ),
927 Box::new(e),
928 )
929 })?;
930 let then = then.cast_to(&then_else_coerce_type, schema)?;
931 Ok((Box::new(when), Box::new(then)))
932 })
933 .collect::<Result<Vec<_>>>()?;
934 let else_expr = case
935 .else_expr
936 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
937 .transpose()?
938 .map(Box::new);
939
940 Ok(Case::new(case_expr, when_then, else_expr))
941}
942
943pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
948 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
949}
950fn coerce_union_schema_with_schema(
951 inputs: &[Arc<LogicalPlan>],
952 base_schema: &DFSchemaRef,
953) -> Result<DFSchema> {
954 let mut union_datatypes = base_schema
955 .fields()
956 .iter()
957 .map(|f| f.data_type().clone())
958 .collect::<Vec<_>>();
959 let mut union_nullabilities = base_schema
960 .fields()
961 .iter()
962 .map(|f| f.is_nullable())
963 .collect::<Vec<_>>();
964 let mut union_field_meta = base_schema
965 .fields()
966 .iter()
967 .map(|f| f.metadata().clone())
968 .collect::<Vec<_>>();
969
970 let mut metadata = base_schema.metadata().clone();
971
972 for (i, plan) in inputs.iter().enumerate() {
973 let plan_schema = plan.schema();
974 metadata.extend(plan_schema.metadata().clone());
975
976 if plan_schema.fields().len() != base_schema.fields().len() {
977 return plan_err!(
978 "Union schemas have different number of fields: \
979 query 1 has {} fields whereas query {} has {} fields",
980 base_schema.fields().len(),
981 i + 1,
982 plan_schema.fields().len()
983 );
984 }
985
986 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
988 union_datatypes.iter_mut(),
989 union_nullabilities.iter_mut(),
990 union_field_meta.iter_mut(),
991 plan_schema.fields().iter()
992 ) {
993 let coerced_type =
994 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
995 || {
996 plan_datafusion_err!(
997 "Incompatible inputs for Union: Previous inputs were \
998 of type {}, but got incompatible type {} on column '{}'",
999 union_datatype,
1000 plan_field.data_type(),
1001 plan_field.name()
1002 )
1003 },
1004 )?;
1005
1006 *union_datatype = coerced_type;
1007 *union_nullable = *union_nullable || plan_field.is_nullable();
1008 union_field_map.extend(plan_field.metadata().clone());
1009 }
1010 }
1011 let union_qualified_fields = izip!(
1012 base_schema.fields(),
1013 union_datatypes.into_iter(),
1014 union_nullabilities,
1015 union_field_meta.into_iter()
1016 )
1017 .map(|(field, datatype, nullable, metadata)| {
1018 let mut field = Field::new(field.name().clone(), datatype, nullable);
1019 field.set_metadata(metadata);
1020 (None, field.into())
1021 })
1022 .collect::<Vec<_>>();
1023
1024 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1025}
1026
1027fn project_with_column_index(
1029 expr: Vec<Expr>,
1030 input: Arc<LogicalPlan>,
1031 schema: DFSchemaRef,
1032) -> Result<LogicalPlan> {
1033 let alias_expr = expr
1034 .into_iter()
1035 .enumerate()
1036 .map(|(i, e)| match e {
1037 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1038 Ok(e.unalias().alias(schema.field(i).name()))
1039 }
1040 Expr::Column(Column {
1041 relation: _,
1042 ref name,
1043 spans: _,
1044 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1045 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1046 #[expect(deprecated)]
1047 Expr::Wildcard { .. } => {
1048 plan_err!("Wildcard should be expanded before type coercion")
1049 }
1050 _ => Ok(e.alias(schema.field(i).name())),
1051 })
1052 .collect::<Result<Vec<_>>>()?;
1053
1054 Projection::try_new_with_schema(alias_expr, input, schema)
1055 .map(LogicalPlan::Projection)
1056}
1057
1058#[cfg(test)]
1059mod test {
1060 use std::any::Any;
1061 use std::sync::Arc;
1062
1063 use arrow::datatypes::DataType::Utf8;
1064 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1065 use insta::assert_snapshot;
1066
1067 use crate::analyzer::type_coercion::{
1068 coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1069 };
1070 use crate::analyzer::Analyzer;
1071 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1072 use datafusion_common::config::ConfigOptions;
1073 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1074 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1075 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1076 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1077 use datafusion_expr::test::function_stub::avg_udaf;
1078 use datafusion_expr::{
1079 cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1080 BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1081 Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1082 SimpleAggregateUDF, Subquery, Union, Volatility,
1083 };
1084 use datafusion_functions_aggregate::average::AvgAccumulator;
1085 use datafusion_sql::TableReference;
1086
1087 fn empty() -> Arc<LogicalPlan> {
1088 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1089 produce_one_row: false,
1090 schema: Arc::new(DFSchema::empty()),
1091 }))
1092 }
1093
1094 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1095 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1096 produce_one_row: false,
1097 schema: Arc::new(
1098 DFSchema::from_unqualified_fields(
1099 vec![Field::new("a", data_type, true)].into(),
1100 std::collections::HashMap::new(),
1101 )
1102 .unwrap(),
1103 ),
1104 }))
1105 }
1106
1107 macro_rules! assert_analyzed_plan_eq {
1108 (
1109 $plan: expr,
1110 @ $expected: literal $(,)?
1111 ) => {{
1112 let options = ConfigOptions::default();
1113 let rule = Arc::new(TypeCoercion::new());
1114 assert_analyzed_plan_with_config_eq_snapshot!(
1115 options,
1116 rule,
1117 $plan,
1118 @ $expected,
1119 )
1120 }};
1121 }
1122
1123 macro_rules! coerce_on_output_if_viewtype {
1124 (
1125 $is_viewtype: expr,
1126 $plan: expr,
1127 @ $expected: literal $(,)?
1128 ) => {{
1129 let mut options = ConfigOptions::default();
1130 if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1132 let rule = Arc::new(TypeCoercion::new());
1133
1134 assert_analyzed_plan_with_config_eq_snapshot!(
1135 options,
1136 rule,
1137 $plan,
1138 @ $expected,
1139 )
1140 }};
1141 }
1142
1143 fn assert_type_coercion_error(
1144 plan: LogicalPlan,
1145 expected_substr: &str,
1146 ) -> Result<()> {
1147 let options = ConfigOptions::default();
1148 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1149
1150 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1151 Ok(succeeded_plan) => {
1152 panic!(
1153 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1154 );
1155 }
1156 Err(e) => {
1157 let msg = e.to_string();
1158 assert!(
1159 msg.contains(expected_substr),
1160 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1161 );
1162 }
1163 }
1164
1165 Ok(())
1166 }
1167
1168 #[test]
1169 fn simple_case() -> Result<()> {
1170 let expr = col("a").lt(lit(2_u32));
1171 let empty = empty_with_type(DataType::Float64);
1172 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1173
1174 assert_analyzed_plan_eq!(
1175 plan,
1176 @r"
1177 Projection: a < CAST(UInt32(2) AS Float64)
1178 EmptyRelation
1179 "
1180 )
1181 }
1182
1183 #[test]
1184 fn test_coerce_union() -> Result<()> {
1185 let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1186 produce_one_row: false,
1187 schema: Arc::new(
1188 DFSchema::try_from_qualified_schema(
1189 TableReference::full("datafusion", "test", "foo"),
1190 &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1191 )
1192 .unwrap(),
1193 ),
1194 }));
1195 let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1196 produce_one_row: false,
1197 schema: Arc::new(
1198 DFSchema::try_from_qualified_schema(
1199 TableReference::full("datafusion", "test", "foo"),
1200 &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1201 )
1202 .unwrap(),
1203 ),
1204 }));
1205 let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1206 left_plan, right_plan,
1207 ])?);
1208 let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1209 .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1210 let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1211 vec![col("a")],
1212 Arc::new(analyzed_union),
1213 )?);
1214
1215 assert_analyzed_plan_eq!(
1216 top_level_plan,
1217 @r"
1218 Projection: a
1219 Union
1220 Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1221 EmptyRelation
1222 EmptyRelation
1223 "
1224 )
1225 }
1226
/// Checks how `Utf8View` output columns are coerced for four plan shapes.
///
/// The first argument of `coerce_on_output_if_viewtype!` presumably toggles the
/// `expand_views_at_output` option (the macro is defined elsewhere — confirm).
/// Per the snapshots below, with the flag on a view-typed output column is
/// wrapped in `CAST(.. AS LargeUtf8)` at the top of the plan, while outputs
/// that are not view-typed are left untouched.
#[test]
fn coerce_utf8view_output() -> Result<()> {
    // Plan A: the outermost projection exposes the Utf8View column itself.
    let expr = col("a");
    let empty = empty_with_type(DataType::Utf8View);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![expr.clone()],
        Arc::clone(&empty),
    )?);

    // Plan A, flag off: nothing changes.
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
    Projection: a
      EmptyRelation
    "
    )?;

    // Plan A, flag on: the output column is cast.
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      EmptyRelation
    "
    )?;

    // Plan B: the outermost projection is a boolean comparison, so the
    // output column is not view-typed.
    let bool_expr = col("a").lt(lit("foo"));
    let bool_plan = LogicalPlan::Projection(Projection::try_new(
        vec![bool_expr],
        Arc::clone(&empty),
    )?);

    // Plan B, flag off: only the comparison operand is coerced.
    coerce_on_output_if_viewtype!(
        false,
        bool_plan.clone(),
        @r#"
    Projection: a < CAST(Utf8("foo") AS Utf8View)
      EmptyRelation
    "#
    )?;

    // Plan B, flag on: the boolean output needs no output cast. This check
    // previously re-ran Plan A by mistake, leaving this combination untested.
    coerce_on_output_if_viewtype!(
        true,
        bool_plan.clone(),
        @r#"
    Projection: a < CAST(Utf8("foo") AS Utf8View)
      EmptyRelation
    "#
    )?;

    // Plan C: a sort sits on top of the projection.
    let sort_expr = expr.sort(true, true);
    let sort_plan = LogicalPlan::Sort(Sort {
        expr: vec![sort_expr],
        input: Arc::new(plan),
        fetch: None,
    });

    // Plan C, flag off: nothing changes.
    coerce_on_output_if_viewtype!(
        false,
        sort_plan.clone(),
        @r"
    Sort: a ASC NULLS FIRST
      Projection: a
        EmptyRelation
    "
    )?;

    // Plan C, flag on: a casting projection is added above the sort.
    coerce_on_output_if_viewtype!(
        true,
        sort_plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation
    "
    )?;

    // Plan D: an extra projection wraps the sort; only the outermost
    // projection is expected to carry the cast.
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sort_plan),
    )?);

    // Plan D, flag off: nothing changes.
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
    Projection: a
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation
    "
    )?;

    // Plan D, flag on: the inner projection is left alone.
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
    Projection: CAST(a AS LargeUtf8)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation
    "
    )?;

    Ok(())
}
1357
/// Checks how `BinaryView` output columns are coerced for four plan shapes.
///
/// The first macro argument presumably toggles `expand_views_at_output`
/// (macro defined elsewhere — confirm). Per the snapshots, with the flag on a
/// view-typed output column is cast to `LargeBinary` at the top of the plan.
#[test]
fn coerce_binaryview_output() -> Result<()> {
    let source = empty_with_type(DataType::BinaryView);

    // Plan A: the outermost projection exposes the BinaryView column itself.
    let base_plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::clone(&source),
    )?);

    coerce_on_output_if_viewtype!(
        false,
        base_plan.clone(),
        @r"
    Projection: a
      EmptyRelation
    "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        base_plan.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      EmptyRelation
    "
    )?;

    // Plan B: the outermost projection is a boolean comparison, so the
    // output is identical with the flag off and on.
    let comparison = col("a").lt(lit(vec![8, 1, 8, 1]));
    let comparison_plan = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        Arc::clone(&source),
    )?);

    coerce_on_output_if_viewtype!(
        false,
        comparison_plan.clone(),
        @r#"
    Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
      EmptyRelation
    "#
    )?;

    coerce_on_output_if_viewtype!(
        true,
        comparison_plan.clone(),
        @r#"
    Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
      EmptyRelation
    "#
    )?;

    // Plan C: a sort sits on top of Plan A.
    let sorted_plan = LogicalPlan::Sort(Sort {
        expr: vec![col("a").sort(true, true)],
        input: Arc::new(base_plan),
        fetch: None,
    });

    coerce_on_output_if_viewtype!(
        false,
        sorted_plan.clone(),
        @r"
    Sort: a ASC NULLS FIRST
      Projection: a
        EmptyRelation
    "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        sorted_plan.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation
    "
    )?;

    // Plan D: an extra projection wraps the sort.
    let outer_plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sorted_plan),
    )?);

    coerce_on_output_if_viewtype!(
        false,
        outer_plan.clone(),
        @r"
    Projection: a
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation
    "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        outer_plan.clone(),
        @r"
    Projection: CAST(a AS LargeBinary)
      Sort: a ASC NULLS FIRST
        Projection: a
          EmptyRelation
    "
    )?;

    Ok(())
}
1481
/// Both operands of an `OR` get their literal coerced to the Float64 column type.
#[test]
fn nested_case() -> Result<()> {
    let comparison = col("a").lt(lit(2_u32));
    let disjunction = comparison.clone().or(comparison);

    let input = empty_with_type(DataType::Float64);
    let projection = Projection::try_new(vec![disjunction], input)?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
      EmptyRelation
    "
    )
}
1500
1501 #[derive(Debug, Clone)]
1502 struct TestScalarUDF {
1503 signature: Signature,
1504 }
1505
1506 impl ScalarUDFImpl for TestScalarUDF {
1507 fn as_any(&self) -> &dyn Any {
1508 self
1509 }
1510
1511 fn name(&self) -> &str {
1512 "TestScalarUDF"
1513 }
1514
1515 fn signature(&self) -> &Signature {
1516 &self.signature
1517 }
1518
1519 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1520 Ok(Utf8)
1521 }
1522
1523 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1524 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1525 }
1526 }
1527
/// An Int32 literal passed to a UDF that accepts one Float32 argument is cast
/// to Float32.
#[test]
fn scalar_udf() -> Result<()> {
    let input = empty();

    let float32_udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float32_udf.call(vec![lit(123_i32)]);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
      EmptyRelation
    "
    )
}
1546
/// Building a projection over a Float32-only UDF called with a string literal
/// must fail.
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
    let input = empty();

    let float32_udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float32_udf.call(vec![lit("Apple")]);

    let attempt = Projection::try_new(vec![call], input);
    attempt.expect_err("Expected an error due to incorrect function input");

    Ok(())
}
1559
/// Same as the UDF call test, but the `Expr::ScalarFunction` node is
/// assembled by hand: the Int64 literal argument is cast to Float32.
#[test]
fn scalar_function() -> Result<()> {
    let input = empty();

    let udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    }));
    let call = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(10i64)]));

    let projection = Projection::try_new(vec![call], input)?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
      EmptyRelation
    "
    )
}
1583
/// An Int64 literal passed to a user-defined aggregate declared over Float64
/// is cast to Float64.
#[test]
fn agg_udaf() -> Result<()> {
    let input = empty();

    let my_avg = create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );

    let aggregate = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit(10i64)],
        false,
        None,
        vec![],
        None,
    );
    let projection = Projection::try_new(vec![Expr::AggregateFunction(aggregate)], input)?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: MY_AVG(CAST(Int64(10) AS Float64))
      EmptyRelation
    "
    )
}
1613
/// A string literal cannot be coerced to the Float64-only signature of
/// `MY_AVG`; building the projection fails with a planning error.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let input = empty();

    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let state_fields = vec![
        Field::new("count", DataType::UInt64, true).into(),
        Field::new("avg", DataType::Float64, true).into(),
    ];
    let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
        "MY_AVG",
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        DataType::Float64,
        accumulator,
        state_fields,
    ));

    let aggregate = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit("10")],
        false,
        None,
        vec![],
        None,
    );
    let udaf = Expr::AggregateFunction(aggregate);

    let err = Projection::try_new(vec![udaf], input).err().unwrap();
    let message = err.strip_backtrace();
    let expected_prefix = "Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed";
    assert!(message.starts_with(expected_prefix));
    Ok(())
}
1645
/// `avg` over an argument that is already Float64 (a literal, or an explicit
/// cast of an Int32 column) is left unchanged by the analyzer.
#[test]
fn agg_function_case() -> Result<()> {
    // Wraps `arg` in a non-distinct, unfiltered, unordered `avg` call.
    let avg_of = |arg: Expr| {
        Expr::AggregateFunction(expr::AggregateFunction::new_udf(
            avg_udaf(),
            vec![arg],
            false,
            None,
            vec![],
            None,
        ))
    };

    // Float64 literal argument.
    let literal_plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of(lit(12f64))],
        empty(),
    )?);

    assert_analyzed_plan_eq!(
        literal_plan,
        @r"
    Projection: avg(Float64(12))
      EmptyRelation
    "
    )?;

    // Int32 column explicitly cast to Float64.
    let column_plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of(cast(col("a"), DataType::Float64))],
        empty_with_type(DataType::Int32),
    )?);

    assert_analyzed_plan_eq!(
        column_plan,
        @r"
    Projection: avg(CAST(a AS Float64))
      EmptyRelation
    "
    )
}
1686
/// `avg` over a string literal is rejected while building the projection.
#[test]
fn agg_function_invalid_input_avg() -> Result<()> {
    let input = empty();

    let aggregate = expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![lit("1")],
        false,
        None,
        vec![],
        None,
    );
    let agg_expr = Expr::AggregateFunction(aggregate);

    let failure = Projection::try_new(vec![agg_expr], input).err().unwrap();
    let message = failure.strip_backtrace();
    let expected_prefix = "Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed";
    assert!(message.starts_with(expected_prefix));
    Ok(())
}
1705
/// Adding a day-time interval to a Date32 value needs no additional casts.
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let interval = lit(ScalarValue::new_interval_dt(123, 456));
    let sum = date + interval;

    let projection = Projection::try_new(vec![sum], empty())?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
      EmptyRelation
    "#
    )
}
1722
/// `IN` list coercion: with an Int64 column the narrower list items are cast
/// to Int64; with a Decimal128(12, 4) column both the column and every list
/// item are cast to Decimal128(24, 4).
#[test]
fn inlist_case() -> Result<()> {
    // The same mixed-width integer list is used for both column types.
    let membership = || col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);

    // Column `a` is Int64.
    let int_input = empty_with_type(DataType::Int64);
    let int_plan =
        LogicalPlan::Projection(Projection::try_new(vec![membership()], int_input)?);
    assert_analyzed_plan_eq!(
        int_plan,
        @r"
    Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
      EmptyRelation
    ")?;

    // Column `a` is a nullable Decimal128(12, 4).
    let decimal_schema = DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
        std::collections::HashMap::new(),
    )?;
    let decimal_input = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: false,
        schema: Arc::new(decimal_schema),
    }));
    let decimal_plan =
        LogicalPlan::Projection(Projection::try_new(vec![membership()], decimal_input)?);
    assert_analyzed_plan_eq!(
        decimal_plan,
        @r"
    Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
      EmptyRelation
    ")
}
1753
/// `BETWEEN` with a string lower bound and a Date32-valued upper bound casts
/// the Utf8 column and the lower bound to Date32.
#[test]
fn between_case() -> Result<()> {
    let lower = lit("2002-05-08");
    let upper = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let predicate = col("a").between(lower, upper);

    let input = empty_with_type(Utf8);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
      EmptyRelation
    "#
    )
}
1773
/// `BETWEEN` with a Date32-valued lower bound and a string upper bound: the
/// Utf8 column and the upper bound are cast to Date32.
#[test]
fn between_infer_cheap_type() -> Result<()> {
    let lower = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let upper = lit("2002-12-08");
    let predicate = col("a").between(lower, upper);

    let input = empty_with_type(Utf8);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
      EmptyRelation
    "#
    )
}
1794
/// Untyped NULL operands of `BETWEEN` take on the Int64 type of the one
/// typed bound.
#[test]
fn between_null() -> Result<()> {
    let subject = lit(ScalarValue::Null);
    let predicate = subject.between(lit(ScalarValue::Null), lit(2i64));

    let filter = Filter::try_new(predicate, empty())?;
    let plan = LogicalPlan::Filter(filter);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
      EmptyRelation
    "
    )
}
1809
/// `IS [NOT] TRUE` / `IS [NOT] FALSE` pass through unchanged over a Boolean
/// column, while `IS TRUE` over an Int64 column is a coercion error.
#[test]
fn is_bool_for_type_coercion() -> Result<()> {
    // Projects `predicate` over an empty relation whose column `a` has `a_type`.
    let project = |predicate: Expr, a_type: DataType| -> Result<LogicalPlan> {
        let projection = Projection::try_new(vec![predicate], empty_with_type(a_type))?;
        Ok(LogicalPlan::Projection(projection))
    };

    // IS TRUE over a Boolean column.
    let plan = project(col("a").is_true(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS TRUE
      EmptyRelation
    "
    )?;

    // IS TRUE over an Int64 column cannot be coerced.
    let plan = project(col("a").is_true(), DataType::Int64)?;
    assert_type_coercion_error(
        plan,
        "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
    )?;

    // IS NOT TRUE
    let plan = project(col("a").is_not_true(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT TRUE
      EmptyRelation
    "
    )?;

    // IS FALSE
    let plan = project(col("a").is_false(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS FALSE
      EmptyRelation
    "
    )?;

    // IS NOT FALSE
    let plan = project(col("a").is_not_false(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT FALSE
      EmptyRelation
    "
    )
}
1872
/// `LIKE` and `ILIKE` coercion: a Utf8 pattern on a Utf8 column is unchanged,
/// a NULL pattern is cast to Utf8, and an Int64 column is a coercion error.
#[test]
fn like_for_type_coercion() -> Result<()> {
    // Projects a non-negated, escape-less `a LIKE <pattern>` (or ILIKE when
    // `case_insensitive`) over a relation whose column `a` has `a_type`.
    let like_plan = |pattern: ScalarValue,
                     a_type: DataType,
                     case_insensitive: bool|
     -> Result<LogicalPlan> {
        let like = Like::new(
            false,
            Box::new(col("a")),
            Box::new(lit(pattern)),
            None,
            case_insensitive,
        );
        let projection =
            Projection::try_new(vec![Expr::Like(like)], empty_with_type(a_type))?;
        Ok(LogicalPlan::Projection(projection))
    };

    // LIKE with a string pattern.
    let plan = like_plan(ScalarValue::new_utf8("abc"), Utf8, false)?;
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a LIKE Utf8("abc")
      EmptyRelation
    "#
    )?;

    // LIKE with a NULL pattern.
    let plan = like_plan(ScalarValue::Null, Utf8, false)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a LIKE CAST(NULL AS Utf8)
      EmptyRelation
    "
    )?;

    // LIKE over an integer column.
    let plan = like_plan(ScalarValue::new_utf8("abc"), DataType::Int64, false)?;
    assert_type_coercion_error(
        plan,
        "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
    )?;

    // ILIKE with a string pattern.
    let plan = like_plan(ScalarValue::new_utf8("abc"), Utf8, true)?;
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a ILIKE Utf8("abc")
      EmptyRelation
    "#
    )?;

    // ILIKE with a NULL pattern.
    let plan = like_plan(ScalarValue::Null, Utf8, true)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a ILIKE CAST(NULL AS Utf8)
      EmptyRelation
    "
    )?;

    // ILIKE over an integer column.
    let plan = like_plan(ScalarValue::new_utf8("abc"), DataType::Int64, true)?;
    assert_type_coercion_error(
        plan,
        "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
    )?;

    Ok(())
}
1955
/// `IS [NOT] UNKNOWN` passes through unchanged over a Boolean column, while
/// `IS UNKNOWN` over a Utf8 column is a coercion error.
#[test]
fn unknown_for_type_coercion() -> Result<()> {
    // Projects `predicate` over an empty relation whose column `a` has `a_type`.
    let project = |predicate: Expr, a_type: DataType| -> Result<LogicalPlan> {
        let projection = Projection::try_new(vec![predicate], empty_with_type(a_type))?;
        Ok(LogicalPlan::Projection(projection))
    };

    // IS UNKNOWN over a Boolean column.
    let plan = project(col("a").is_unknown(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS UNKNOWN
      EmptyRelation
    "
    )?;

    // IS UNKNOWN over a string column cannot be coerced.
    let plan = project(col("a").is_unknown(), Utf8)?;
    assert_type_coercion_error(
        plan,
        "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
    )?;

    // IS NOT UNKNOWN over a Boolean column.
    let plan = project(col("a").is_not_unknown(), DataType::Boolean)?;
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Projection: a IS NOT UNKNOWN
      EmptyRelation
    "
    )
}
1992
/// A UDF with a variadic Utf8 signature leaves string arguments alone and
/// casts boolean and integer literals to Utf8.
#[test]
fn concat_for_type_coercion() -> Result<()> {
    let input = empty_with_type(Utf8);
    let mixed_args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];

    let variadic_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
    });
    let call = variadic_udf.call(mixed_args.to_vec());

    let projection = Projection::try_new(vec![call], Arc::clone(&input))?;
    let plan = LogicalPlan::Projection(projection);
    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
      EmptyRelation
    "#
    )
}
2013
/// Drives `TypeCoercionRewriter` directly: an Int32 literal compared against
/// an Int64 literal inside `IS TRUE` is cast to Int64 for `>`, `=` and `<`.
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
    // Rewrites `input` against a fresh single-column (`a`: Int64) schema and
    // compares the outcome with `expected`.
    let check = |input: Expr, expected: Expr| -> Result<()> {
        let schema = Arc::new(DFSchema::from_unqualified_fields(
            vec![Field::new("a", DataType::Int64, true)].into(),
            std::collections::HashMap::new(),
        )?);
        let mut rewriter = TypeCoercionRewriter { schema: &schema };
        let result = input.rewrite(&mut rewriter).data()?;
        assert_eq!(expected, result);
        Ok(())
    };

    // is_true over `>`
    check(
        is_true(lit(12i32).gt(lit(13i64))),
        is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))),
    )?;

    // is_true over `=`
    check(
        is_true(lit(12i32).eq(lit(13i64))),
        is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))),
    )?;

    // is_true over `<`
    check(
        is_true(lit(12i32).lt(lit(13i64))),
        is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))),
    )?;

    Ok(())
}
2051
/// Comparing a nanosecond timestamp with a Date32 value casts the date side
/// to the timestamp type.
#[test]
fn binary_op_date32_eq_ts() -> Result<()> {
    let timestamp = cast(
        lit("1998-03-18"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let comparison = timestamp.eq(date);

    let projection = Projection::try_new(vec![comparison], empty())?;
    let plan = LogicalPlan::Projection(projection);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None))
      EmptyRelation
    "#
    )
}
2070
2071 fn cast_if_not_same_type(
2072 expr: Box<Expr>,
2073 data_type: &DataType,
2074 schema: &DFSchemaRef,
2075 ) -> Box<Expr> {
2076 if &expr.get_type(schema).unwrap() != data_type {
2077 Box::new(cast(*expr, data_type.clone()))
2078 } else {
2079 expr
2080 }
2081 }
2082
2083 fn cast_helper(
2084 case: Case,
2085 case_when_type: &DataType,
2086 then_else_type: &DataType,
2087 schema: &DFSchemaRef,
2088 ) -> Case {
2089 let expr = case
2090 .expr
2091 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2092 let when_then_expr = case
2093 .when_then_expr
2094 .into_iter()
2095 .map(|(when, then)| {
2096 (
2097 cast_if_not_same_type(when, case_when_type, schema),
2098 cast_if_not_same_type(then, then_else_type, schema),
2099 )
2100 })
2101 .collect::<Vec<_>>();
2102 let else_expr = case
2103 .else_expr
2104 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2105
2106 Case {
2107 expr,
2108 when_then_expr,
2109 else_expr,
2110 }
2111 }
2112
/// Exercises `coerce_case_expression` on four CASE shapes over a schema with
/// one column per tested type: two that coerce (compared against
/// `cast_helper`) and two that fail with a planning error.
#[test]
fn test_case_expression_coercion() -> Result<()> {
    let fields = vec![
        Field::new("boolean", DataType::Boolean, true),
        Field::new("integer", DataType::Int32, true),
        Field::new("float", DataType::Float32, true),
        Field::new(
            "timestamp",
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            true,
        ),
        Field::new("date", DataType::Date32, true),
        Field::new(
            "interval",
            DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
            true,
        ),
        Field::new("binary", DataType::Binary, true),
        Field::new("string", Utf8, true),
        Field::new("decimal", DataType::Decimal128(10, 10), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        fields.into(),
        std::collections::HashMap::new(),
    )?);

    // Shorthand for a boxed column reference.
    let boxed = |name: &str| Box::new(col(name));

    // No base expression: WHENs coerce to Boolean, THENs to Utf8.
    let searched_case = Case {
        expr: None,
        when_then_expr: vec![
            (boxed("boolean"), boxed("integer")),
            (boxed("integer"), boxed("float")),
            (boxed("string"), boxed("string")),
        ],
        else_expr: None,
    };
    let expected = cast_helper(searched_case.clone(), &DataType::Boolean, &Utf8, &schema);
    let actual = coerce_case_expression(searched_case, &schema)?;
    assert_eq!(expected, actual);

    // String base expression and ELSE: everything coerces to Utf8.
    let simple_case = Case {
        expr: Some(boxed("string")),
        when_then_expr: vec![
            (boxed("float"), boxed("integer")),
            (boxed("integer"), boxed("float")),
            (boxed("string"), boxed("string")),
        ],
        else_expr: Some(boxed("string")),
    };
    let expected = cast_helper(simple_case.clone(), &Utf8, &Utf8, &schema);
    let actual = coerce_case_expression(simple_case, &schema)?;
    assert_eq!(expected, actual);

    // Interval base expression has no common type with the WHEN expressions.
    let bad_when_case = Case {
        expr: Some(boxed("interval")),
        when_then_expr: vec![
            (boxed("float"), boxed("integer")),
            (boxed("binary"), boxed("float")),
            (boxed("string"), boxed("string")),
        ],
        else_expr: Some(boxed("string")),
    };
    let err = coerce_case_expression(bad_when_case, &schema).unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression"
    );

    // THEN and ELSE expressions have no common type.
    let bad_then_case = Case {
        expr: Some(boxed("string")),
        when_then_expr: vec![
            (boxed("float"), boxed("date")),
            (boxed("string"), boxed("float")),
            (boxed("string"), boxed("binary")),
        ],
        else_expr: Some(boxed("timestamp")),
    };
    let err = coerce_case_expression(bad_then_case, &schema).unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression"
    );

    Ok(())
}
2211
/// Builds an ELSE-less `Case` from an optional base column name and the given
/// WHEN/THEN pairs, then asserts that `coerce_case_expression` produces what
/// `cast_helper` builds for the stated WHEN and THEN target types.
///
/// Expands to statements that use `?`, so it must be invoked inside a
/// function returning `Result`.
macro_rules! test_case_expression {
    ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
        let base_expr = $expr.map(|name| Box::new(col(name)));
        let input_case = Case {
            expr: base_expr,
            when_then_expr: $when_then,
            else_expr: None,
        };

        let want = cast_helper(
            input_case.clone(),
            &$case_when_type,
            &$then_else_type,
            &$schema,
        );

        let got = coerce_case_expression(input_case, &$schema)?;
        assert_eq!(want, got);
    };
}
2227
/// CASE base / WHEN coercion between the three list flavours of Int64:
/// any pairing involving `LargeList` expects `LargeList`, and pairing `List`
/// with `FixedSizeList` expects `List`. The THEN branch is always a string.
#[test]
fn tes_case_when_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let columns = vec![
        Field::new(
            "large_list",
            DataType::LargeList(Arc::clone(&inner_field)),
            true,
        ),
        Field::new(
            "fixed_list",
            DataType::FixedSizeList(Arc::clone(&inner_field), 3),
            true,
        ),
        Field::new("list", DataType::List(inner_field), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        columns.into(),
        std::collections::HashMap::new(),
    )?);

    // Expected common types for the base / WHEN side.
    let large_list_type =
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true)));
    let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));

    // A single `WHEN <column> THEN '1'` branch.
    let single_branch =
        |when_column: &str| vec![(Box::new(col(when_column)), Box::new(lit("1")))];

    // list vs large_list
    test_case_expression!(
        Some("list"),
        single_branch("large_list"),
        large_list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("large_list"),
        single_branch("list"),
        large_list_type,
        Utf8,
        schema
    );

    // list vs fixed_list
    test_case_expression!(
        Some("list"),
        single_branch("fixed_list"),
        list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("fixed_list"),
        single_branch("list"),
        list_type,
        Utf8,
        schema
    );

    // fixed_list vs large_list
    test_case_expression!(
        Some("fixed_list"),
        single_branch("large_list"),
        large_list_type,
        Utf8,
        schema
    );

    test_case_expression!(
        Some("large_list"),
        single_branch("fixed_list"),
        large_list_type,
        Utf8,
        schema
    );
    Ok(())
}
2298
/// THEN-branch coercion between the three list flavours of Int64, with
/// Boolean WHEN conditions and no base expression: any pairing involving
/// `LargeList` expects `LargeList`, and `List` with `FixedSizeList` expects
/// `List`, regardless of branch order.
#[test]
fn test_then_else_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let columns = vec![
        Field::new("boolean", DataType::Boolean, true),
        Field::new(
            "large_list",
            DataType::LargeList(Arc::clone(&inner_field)),
            true,
        ),
        Field::new(
            "fixed_list",
            DataType::FixedSizeList(Arc::clone(&inner_field), 3),
            true,
        ),
        Field::new("list", DataType::List(inner_field), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        columns.into(),
        std::collections::HashMap::new(),
    )?);

    // Expected common types for the THEN side.
    let large_list_type =
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true)));
    let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));

    // Two `WHEN boolean THEN <column>` branches, in the given order.
    let two_branches = |first_then: &str, second_then: &str| {
        vec![
            (Box::new(col("boolean")), Box::new(col(first_then))),
            (Box::new(col("boolean")), Box::new(col(second_then))),
        ]
    };

    // large_list vs list
    test_case_expression!(
        None::<String>,
        two_branches("large_list", "list"),
        DataType::Boolean,
        large_list_type,
        schema
    );

    test_case_expression!(
        None::<String>,
        two_branches("list", "large_list"),
        DataType::Boolean,
        large_list_type,
        schema
    );

    // fixed_list vs list
    test_case_expression!(
        None::<String>,
        two_branches("fixed_list", "list"),
        DataType::Boolean,
        list_type,
        schema
    );

    test_case_expression!(
        None::<String>,
        two_branches("list", "fixed_list"),
        DataType::Boolean,
        list_type,
        schema
    );

    // fixed_list vs large_list
    test_case_expression!(
        None::<String>,
        two_branches("fixed_list", "large_list"),
        DataType::Boolean,
        large_list_type,
        schema
    );

    test_case_expression!(
        None::<String>,
        two_branches("large_list", "fixed_list"),
        DataType::Boolean,
        large_list_type,
        schema
    );
    Ok(())
}
2391
/// Two map types that differ only in the name of their entries field
/// ("entries" vs "key_value") are reconciled by casting the right-hand side
/// back to the column's own map type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    let mut builder = SchemaBuilder::new();
    builder.push(Field::new("key", Utf8, false));
    builder.push(Field::new("value", DataType::Float64, true));
    let struct_fields = builder.finish().fields;

    // Map type of column `a`: entries field named "entries".
    let entries_field =
        Field::new("entries", DataType::Struct(struct_fields.clone()), false);
    let map_type_entries = DataType::Map(Arc::new(entries_field), false);

    // Same layout, but the entries field is named "key_value".
    let key_value_field = Field::new("key_value", DataType::Struct(struct_fields), false);
    let map_type_custom = DataType::Map(Arc::new(key_value_field), false);

    let comparison = col("a").eq(cast(col("a"), map_type_custom));
    let input = empty_with_type(map_type_entries);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![comparison], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))
      EmptyRelation
    "#
    )
}
2418
#[test]
fn interval_plus_timestamp() -> Result<()> {
    // Left operand: a year-month interval literal.
    let interval = lit(ScalarValue::IntervalYearMonth(Some(12)));
    // Right operand: a string literal explicitly cast to a nanosecond timestamp.
    let timestamp = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(interval),
        Operator::Plus,
        Box::new(timestamp),
    ));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], empty())?);

    // Expected: the analyzed projection shows the expression as it was built.
    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None))
          EmptyRelation
        "#
    )
}
2441
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both operands are the same date string cast to a nanosecond timestamp.
    let date_as_timestamp = || {
        cast(
            lit("1998-03-18"),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
    };
    let minuend = Box::new(date_as_timestamp());
    let subtrahend = Box::new(date_as_timestamp());
    let difference =
        Expr::BinaryExpr(BinaryExpr::new(minuend, Operator::Minus, subtrahend));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![difference], empty())?);

    // Expected: the analyzed projection shows the subtraction as it was built.
    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None))
          EmptyRelation
        "#
    )
}
2466
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    // Subquery side is typed Int32, outer (filter input) side is typed Int64.
    let inner_input = empty_with_type(DataType::Int32);
    let outer_input = empty_with_type(DataType::Int64);

    let subquery = Subquery {
        subquery: inner_input,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    // Expected: only the subquery is adjusted, via a projection casting `a` to Int64.
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Filter: a IN (<subquery>)
          Subquery:
            Projection: CAST(a AS Int64)
              EmptyRelation
          EmptyRelation
        "
    )
}
2495
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    // Outer (filter input) side is typed Int32, subquery side is typed Int64.
    let outer_input = empty_with_type(DataType::Int32);
    let inner_input = empty_with_type(DataType::Int64);

    let subquery = Subquery {
        subquery: inner_input,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    // Expected: only the outer expression is cast to Int64; the subquery is untouched.
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Filter: CAST(a AS Int64) IN (<subquery>)
          Subquery:
            EmptyRelation
          EmptyRelation
        "
    )
}
2523
#[test]
fn in_subquery_cast_all() -> Result<()> {
    // The two sides use different decimal precisions/scales, so neither type
    // can be used as-is for the comparison.
    let inner_input = empty_with_type(DataType::Decimal128(10, 5));
    let outer_input = empty_with_type(DataType::Decimal128(8, 8));

    let subquery = Subquery {
        subquery: inner_input,
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    // Expected: both the outer expression and the subquery output are cast to
    // Decimal128(13, 8).
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
          Subquery:
            Projection: CAST(a AS Decimal128(13, 8))
              EmptyRelation
          EmptyRelation
        "
    )
}
2552}