1use super::{Column, Literal};
19use crate::expressions::case::ResultState::{Complete, Empty, Partial};
20use crate::expressions::try_cast;
21use crate::PhysicalExpr;
22use arrow::array::*;
23use arrow::compute::kernels::zip::zip;
24use arrow::compute::{
25 is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate,
26 SlicesIterator,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
32use datafusion_common::{
33 exec_err, internal_datafusion_err, internal_err, DataFusionError, HashMap, HashSet,
34 Result, ScalarValue,
35};
36use datafusion_expr::ColumnarValue;
37use datafusion_physical_expr_common::datum::compare_with_eq;
38use itertools::Itertools;
39use std::borrow::Cow;
40use std::fmt::{Debug, Formatter};
41use std::hash::Hash;
42use std::{any::Any, sync::Arc};
43
/// One `WHEN ... THEN ...` branch of a CASE expression: the first element is the
/// WHEN expression, the second is the THEN expression.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
45
/// The evaluation strategy picked for a [`CaseExpr`] when it is constructed.
///
/// Variants carrying a [`ProjectedCaseBody`] hold a copy of the case body that has
/// been rewritten against only the columns the body references.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression
    /// (`CASE WHEN cond THEN result ... [ELSE result] END`); chosen when there is
    /// no base expression and none of the specialised variants below applies.
    NoExpression(ProjectedCaseBody),
    /// Form with a base expression
    /// (`CASE expr WHEN value THEN result ... [ELSE result] END`).
    WithExpression(ProjectedCaseBody),
    /// Exactly one WHEN branch whose THEN expression is a plain column reference
    /// and no ELSE branch.
    InfallibleExprOrNull,
    /// Exactly one WHEN branch where both the THEN and the ELSE expressions are
    /// literals.
    ScalarOrScalar,
    /// Exactly one WHEN branch together with an ELSE branch, when the more
    /// specific variants above do not apply.
    ExpressionOrExpression(ProjectedCaseBody),
}
76
/// The expressions that make up a CASE expression.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression (`CASE <expr> WHEN ...`).
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// The `WHEN ... THEN ...` branches, in evaluation order.
    when_then_expr: Vec<WhenThen>,
    /// Optional `ELSE` expression.
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
88
89impl CaseBody {
90 fn project(&self) -> Result<ProjectedCaseBody> {
92 let mut used_column_indices = HashSet::<usize>::new();
94 let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
95 expr.apply(|expr| {
96 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
97 used_column_indices.insert(column.index());
98 }
99 Ok(TreeNodeRecursion::Continue)
100 })
101 .expect("Closure cannot fail");
102 };
103
104 if let Some(e) = &self.expr {
105 collect_column_indices(e);
106 }
107 self.when_then_expr.iter().for_each(|(w, t)| {
108 collect_column_indices(w);
109 collect_column_indices(t);
110 });
111 if let Some(e) = &self.else_expr {
112 collect_column_indices(e);
113 }
114
115 let column_index_map = used_column_indices
117 .iter()
118 .enumerate()
119 .map(|(projected, original)| (*original, projected))
120 .collect::<HashMap<usize, usize>>();
121
122 let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
125 Arc::clone(expr)
126 .transform_down(|e| {
127 if let Some(column) = e.as_any().downcast_ref::<Column>() {
128 let original = column.index();
129 let projected = *column_index_map.get(&original).unwrap();
130 if projected != original {
131 return Ok(Transformed::yes(Arc::new(Column::new(
132 column.name(),
133 projected,
134 ))));
135 }
136 }
137 Ok(Transformed::no(e))
138 })
139 .map(|t| t.data)
140 };
141
142 let projected_body = CaseBody {
143 expr: self.expr.as_ref().map(project).transpose()?,
144 when_then_expr: self
145 .when_then_expr
146 .iter()
147 .map(|(e, t)| Ok((project(e)?, project(t)?)))
148 .collect::<Result<Vec<_>>>()?,
149 else_expr: self.else_expr.as_ref().map(project).transpose()?,
150 };
151
152 let projection = column_index_map
154 .iter()
155 .sorted_by_key(|(_, v)| **v)
156 .map(|(k, _)| *k)
157 .collect::<Vec<_>>();
158
159 Ok(ProjectedCaseBody {
160 projection,
161 body: projected_body,
162 })
163 }
164}
165
/// A [`CaseBody`] rewritten to be evaluated against a batch that contains only
/// the columns the body references.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// Indices of the original batch columns to keep; passed to
    /// `RecordBatch::project` before evaluating `body`.
    projection: Vec<usize>,
    /// The case body whose `Column` indices refer to positions in `projection`.
    body: CaseBody,
}
198
/// Physical expression for a SQL-style `CASE` expression.
///
/// Instances are created through [`CaseExpr::try_new`], which also selects the
/// evaluation strategy.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// The expressions of the CASE, indexed against the original input schema.
    body: CaseBody,
    /// The evaluation strategy chosen at construction time.
    eval_method: EvalMethod,
}
223
224impl std::fmt::Display for CaseExpr {
225 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
226 write!(f, "CASE ")?;
227 if let Some(e) = &self.body.expr {
228 write!(f, "{e} ")?;
229 }
230 for (w, t) in &self.body.when_then_expr {
231 write!(f, "WHEN {w} THEN {t} ")?;
232 }
233 if let Some(e) = &self.body.else_expr {
234 write!(f, "ELSE {e} ")?;
235 }
236 write!(f, "END")
237 }
238}
239
240fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
246 expr.as_any().is::<Column>()
247}
248
249fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
251 let mut filter_builder = FilterBuilder::new(predicate);
252 if optimize {
253 filter_builder = filter_builder.optimize();
255 }
256 filter_builder.build()
257}
258
259fn multiple_arrays(data_type: &DataType) -> bool {
260 match data_type {
261 DataType::Struct(fields) => {
262 fields.len() > 1
263 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
264 }
265 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
266 _ => false,
267 }
268}
269
270fn filter_record_batch(
273 record_batch: &RecordBatch,
274 filter: &FilterPredicate,
275) -> std::result::Result<RecordBatch, ArrowError> {
276 let filtered_columns = record_batch
277 .columns()
278 .iter()
279 .map(|a| filter_array(a, filter))
280 .collect::<std::result::Result<Vec<_>, _>>()?;
281 unsafe {
287 Ok(RecordBatch::new_unchecked(
288 record_batch.schema(),
289 filtered_columns,
290 filter.count(),
291 ))
292 }
293}
294
295#[inline(always)]
300fn filter_array(
301 array: &dyn Array,
302 filter: &FilterPredicate,
303) -> std::result::Result<ArrayRef, ArrowError> {
304 filter.filter(array)
305}
306
/// Interleaves `truthy` and `falsy` into one array according to `mask`.
///
/// Rows covered by a run of set bits in `mask` are taken from `truthy`; all other
/// rows are taken from `falsy`. Array inputs are consumed sequentially (the n-th
/// row taken from an input is that input's n-th value), so an array input is
/// expected to hold exactly one value per row it supplies. Scalar inputs are
/// repeated for every row they supply.
///
/// # Errors
/// Returns an error if a scalar cannot be converted to an array, or if `zip`
/// fails in the scalar/scalar case.
fn merge(
    mask: &BooleanArray,
    truthy: ColumnarValue,
    falsy: ColumnarValue,
) -> std::result::Result<ArrayRef, ArrowError> {
    // Normalise both inputs to arrays, remembering which ones were scalars
    // (a scalar becomes an array whose element 0 is the value to repeat).
    let (truthy, truthy_is_scalar) = match truthy {
        ColumnarValue::Array(a) => (a, false),
        ColumnarValue::Scalar(s) => (s.to_array()?, true),
    };
    let (falsy, falsy_is_scalar) = match falsy {
        ColumnarValue::Array(a) => (a, false),
        ColumnarValue::Scalar(s) => (s.to_array()?, true),
    };

    // Two scalars: delegate to the `zip` kernel.
    if truthy_is_scalar && falsy_is_scalar {
        return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy));
    }

    let falsy = falsy.to_data();
    let truthy = truthy.to_data();

    // Source 0 is `truthy`, source 1 is `falsy`.
    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());

    // `filled` is the number of output rows written so far; the offsets track how
    // many values have already been consumed from each array input.
    let mut filled = 0;
    let mut falsy_offset = 0;
    let mut truthy_offset = 0;

    // Each `(start, end)` yielded here is treated as a run of rows to take from
    // `truthy`; the gap before it (if any) is filled from `falsy`.
    SlicesIterator::new(mask).for_each(|(start, end)| {
        if start > filled {
            // Rows `filled..start` are not selected by the mask.
            if falsy_is_scalar {
                for _ in filled..start {
                    // Repeat the single scalar value.
                    mutable.extend(1, 0, 1);
                }
            } else {
                let falsy_length = start - filled;
                let falsy_end = falsy_offset + falsy_length;
                mutable.extend(1, falsy_offset, falsy_end);
                falsy_offset = falsy_end;
            }
        }
        if truthy_is_scalar {
            for _ in start..end {
                // Repeat the single scalar value.
                mutable.extend(0, 0, 1);
            }
        } else {
            let truthy_length = end - start;
            let truthy_end = truthy_offset + truthy_length;
            mutable.extend(0, truthy_offset, truthy_end);
            truthy_offset = truthy_end;
        }
        filled = end;
    });
    // Rows after the last selected run come from `falsy`.
    if filled < mask.len() {
        if falsy_is_scalar {
            for _ in filled..mask.len() {
                mutable.extend(1, 0, 1);
            }
        } else {
            let falsy_length = mask.len() - filled;
            let falsy_end = falsy_offset + falsy_length;
            mutable.extend(1, falsy_offset, falsy_end);
        }
    }

    let data = mutable.freeze();
    Ok(make_array(data))
}
384
385fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
438 #[cfg(debug_assertions)]
439 for ix in indices {
440 if let Some(index) = ix.index() {
441 assert!(
442 index < values.len(),
443 "Index out of bounds: {} >= {}",
444 index,
445 values.len()
446 );
447 }
448 }
449
450 let data_refs = values.iter().collect();
451 let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
452
453 let mut take_offsets = vec![0; values.len() + 1];
457 let mut start_row_ix = 0;
458 loop {
459 let array_ix = indices[start_row_ix];
460
461 let mut end_row_ix = start_row_ix + 1;
463 while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
464 end_row_ix += 1;
465 }
466 let slice_length = end_row_ix - start_row_ix;
467
468 match array_ix.index() {
470 None => mutable.extend_nulls(slice_length),
471 Some(index) => {
472 let start_offset = take_offsets[index];
473 let end_offset = start_offset + slice_length;
474 mutable.extend(index, start_offset, end_offset);
475 take_offsets[index] = end_offset;
476 }
477 }
478
479 if end_row_ix == indices.len() {
480 break;
481 } else {
482 start_row_ix = end_row_ix;
484 }
485 }
486
487 Ok(make_array(mutable.freeze()))
488}
489
/// Compact identifier of the partial result array that supplies a given output
/// row, or a "none" marker when no array supplies it.
///
/// The marker is stored in-band as `NONE_VALUE` rather than through an `Option`.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    /// Index of the partial result array, or `NONE_VALUE` for "none".
    index: u32,
}
498
/// Reserved raw value that [`PartialResultIndex`] uses to represent "none".
const NONE_VALUE: u32 = u32::MAX;
500
501impl PartialResultIndex {
502 fn none() -> Self {
504 Self { index: NONE_VALUE }
505 }
506
507 fn zero() -> Self {
508 Self { index: 0 }
509 }
510
511 fn try_new(index: usize) -> Result<Self> {
516 let Ok(index) = u32::try_from(index) else {
517 return internal_err!("Partial result index exceeds limit");
518 };
519
520 if index == NONE_VALUE {
521 return internal_err!("Partial result index exceeds limit");
522 }
523
524 Ok(Self { index })
525 }
526
527 fn is_none(&self) -> bool {
529 self.index == NONE_VALUE
530 }
531
532 fn index(&self) -> Option<usize> {
534 if self.is_none() {
535 None
536 } else {
537 Some(self.index as usize)
538 }
539 }
540}
541
542impl Debug for PartialResultIndex {
543 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
544 if self.is_none() {
545 write!(f, "null")
546 } else {
547 write!(f, "{}", self.index)
548 }
549 }
550}
551
/// Accumulation state of a [`ResultBuilder`].
enum ResultState {
    /// No branch result has been added yet.
    Empty,
    /// One or more branch results covering subsets of the rows have been added.
    Partial {
        /// The values of each partial result, in the order they were added.
        arrays: Vec<ArrayData>,
        /// For every output row, the entry of `arrays` that supplies its value,
        /// or the "none" marker if no result has been added for that row.
        indices: Vec<PartialResultIndex>,
    },
    /// A single result covering every row has been set.
    Complete(ColumnarValue),
}
567
/// Collects the values produced by the individual CASE branches and assembles
/// them into the final result.
///
/// Each branch contributes values for a set of row indices. A contribution that
/// covers all `row_count` rows is stored as the complete result; smaller
/// contributions are stored as partial results and merged in `finish`.
struct ResultBuilder {
    /// Data type of the null scalar returned when no result was added at all.
    data_type: DataType,
    /// Number of rows in the final result.
    row_count: usize,
    /// What has been accumulated so far.
    state: ResultState,
}
583
584impl ResultBuilder {
585 fn new(data_type: &DataType, row_count: usize) -> Self {
589 Self {
590 data_type: data_type.clone(),
591 row_count,
592 state: Empty,
593 }
594 }
595
596 fn add_branch_result(
629 &mut self,
630 row_indices: &ArrayRef,
631 value: ColumnarValue,
632 ) -> Result<()> {
633 match value {
634 ColumnarValue::Array(a) => {
635 if a.len() != row_indices.len() {
636 internal_err!("Array length must match row indices length")
637 } else if row_indices.len() == self.row_count {
638 self.set_complete_result(ColumnarValue::Array(a))
639 } else {
640 self.add_partial_result(row_indices, a.to_data())
641 }
642 }
643 ColumnarValue::Scalar(s) => {
644 if row_indices.len() == self.row_count {
645 self.set_complete_result(ColumnarValue::Scalar(s))
646 } else {
647 self.add_partial_result(
648 row_indices,
649 s.to_array_of_size(row_indices.len())?.to_data(),
650 )
651 }
652 }
653 }
654 }
655
656 fn add_partial_result(
662 &mut self,
663 row_indices: &ArrayRef,
664 row_values: ArrayData,
665 ) -> Result<()> {
666 if row_indices.null_count() != 0 {
667 return internal_err!("Row indices must not contain nulls");
668 }
669
670 match &mut self.state {
671 Empty => {
672 let array_index = PartialResultIndex::zero();
673 let mut indices = vec![PartialResultIndex::none(); self.row_count];
674 for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
675 indices[*row_ix as usize] = array_index;
676 }
677
678 self.state = Partial {
679 arrays: vec![row_values],
680 indices,
681 };
682
683 Ok(())
684 }
685 Partial { arrays, indices } => {
686 let array_index = PartialResultIndex::try_new(arrays.len())?;
687
688 arrays.push(row_values);
689
690 for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
691 #[cfg(debug_assertions)]
695 if !indices[*row_ix as usize].is_none() {
696 return internal_err!("Duplicate value for row {}", *row_ix);
697 }
698
699 indices[*row_ix as usize] = array_index;
700 }
701 Ok(())
702 }
703 Complete(_) => internal_err!(
704 "Cannot add a partial result when complete result is already set"
705 ),
706 }
707 }
708
709 fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
715 match &self.state {
716 Empty => {
717 self.state = Complete(value);
718 Ok(())
719 }
720 Partial { .. } => {
721 internal_err!(
722 "Cannot set a complete result when there are already partial results"
723 )
724 }
725 Complete(_) => internal_err!("Complete result already set"),
726 }
727 }
728
729 fn finish(self) -> Result<ColumnarValue> {
731 match self.state {
732 Empty => {
733 Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
737 &self.data_type,
738 )?))
739 }
740 Partial { arrays, indices } => {
741 Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?))
743 }
744 Complete(v) => {
745 Ok(v)
747 }
748 }
749 }
750}
751
752impl CaseExpr {
753 pub fn try_new(
755 expr: Option<Arc<dyn PhysicalExpr>>,
756 when_then_expr: Vec<WhenThen>,
757 else_expr: Option<Arc<dyn PhysicalExpr>>,
758 ) -> Result<Self> {
759 let else_expr = match &else_expr {
762 Some(e) => match e.as_any().downcast_ref::<Literal>() {
763 Some(lit) if lit.value().is_null() => None,
764 _ => else_expr,
765 },
766 _ => else_expr,
767 };
768
769 if when_then_expr.is_empty() {
770 return exec_err!("There must be at least one WHEN clause");
771 }
772
773 let body = CaseBody {
774 expr,
775 when_then_expr,
776 else_expr,
777 };
778
779 let eval_method = if body.expr.is_some() {
780 EvalMethod::WithExpression(body.project()?)
781 } else if body.when_then_expr.len() == 1
782 && is_cheap_and_infallible(&(body.when_then_expr[0].1))
783 && body.else_expr.is_none()
784 {
785 EvalMethod::InfallibleExprOrNull
786 } else if body.when_then_expr.len() == 1
787 && body.when_then_expr[0].1.as_any().is::<Literal>()
788 && body.else_expr.is_some()
789 && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
790 {
791 EvalMethod::ScalarOrScalar
792 } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
793 EvalMethod::ExpressionOrExpression(body.project()?)
794 } else {
795 EvalMethod::NoExpression(body.project()?)
796 };
797
798 Ok(Self { body, eval_method })
799 }
800
801 pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
803 self.body.expr.as_ref()
804 }
805
806 pub fn when_then_expr(&self) -> &[WhenThen] {
808 &self.body.when_then_expr
809 }
810
811 pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
813 self.body.else_expr.as_ref()
814 }
815}
816
817impl CaseBody {
818 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
819 let mut data_type = DataType::Null;
822 for i in 0..self.when_then_expr.len() {
823 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
824 if !data_type.equals_datatype(&DataType::Null) {
825 break;
826 }
827 }
828 if data_type.equals_datatype(&DataType::Null) {
830 if let Some(e) = &self.else_expr {
831 data_type = e.data_type(input_schema)?;
832 }
833 }
834
835 Ok(data_type)
836 }
837
838 fn case_when_with_expr(
840 &self,
841 batch: &RecordBatch,
842 return_type: &DataType,
843 ) -> Result<ColumnarValue> {
844 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
845
846 let mut remainder_rows: ArrayRef =
848 Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
849 let mut remainder_batch = Cow::Borrowed(batch);
851
852 let mut base_values = self
854 .expr
855 .as_ref()
856 .unwrap()
857 .evaluate(batch)?
858 .into_array(batch.num_rows())?;
859
860 if base_values.null_count() > 0 {
865 let base_not_nulls = is_not_null(base_values.as_ref())?;
869 let base_all_null = base_values.null_count() == remainder_batch.num_rows();
870
871 if let Some(e) = &self.else_expr {
874 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
875
876 if base_all_null {
877 let nulls_value = expr.evaluate(&remainder_batch)?;
879 result_builder.add_branch_result(&remainder_rows, nulls_value)?;
880 } else {
881 let nulls_filter = create_filter(¬(&base_not_nulls)?, true);
883 let nulls_batch =
884 filter_record_batch(&remainder_batch, &nulls_filter)?;
885 let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
886 let nulls_value = expr.evaluate(&nulls_batch)?;
887 result_builder.add_branch_result(&nulls_rows, nulls_value)?;
888 }
889 }
890
891 if base_all_null {
893 return result_builder.finish();
894 }
895
896 let not_null_filter = create_filter(&base_not_nulls, true);
898 remainder_batch =
899 Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?);
900 remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?;
901 base_values = filter_array(&base_values, ¬_null_filter)?;
902 }
903
904 let base_value_is_nested = base_values.data_type().is_nested();
907
908 for i in 0..self.when_then_expr.len() {
909 let when_expr = &self.when_then_expr[i].0;
912 let when_value = match when_expr.evaluate(&remainder_batch)? {
913 ColumnarValue::Array(a) => {
914 compare_with_eq(&a, &base_values, base_value_is_nested)
915 }
916 ColumnarValue::Scalar(s) => {
917 compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
918 }
919 }?;
920
921 let when_true_count = when_value.true_count();
924
925 if when_true_count == 0 {
927 continue;
928 }
929
930 if when_true_count == remainder_batch.num_rows() {
932 let then_expression = &self.when_then_expr[i].1;
933 let then_value = then_expression.evaluate(&remainder_batch)?;
934 result_builder.add_branch_result(&remainder_rows, then_value)?;
935 return result_builder.finish();
936 }
937
938 let then_filter = create_filter(&when_value, true);
944 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
945 let then_rows = filter_array(&remainder_rows, &then_filter)?;
946
947 let then_expression = &self.when_then_expr[i].1;
948 let then_value = then_expression.evaluate(&then_batch)?;
949 result_builder.add_branch_result(&then_rows, then_value)?;
950
951 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
954 return result_builder.finish();
955 }
956
957 let next_selection = match when_value.null_count() {
959 0 => not(&when_value),
960 _ => {
961 not(&prep_null_mask_filter(&when_value))
964 }
965 }?;
966 let next_filter = create_filter(&next_selection, true);
967 remainder_batch =
968 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
969 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
970 base_values = filter_array(&base_values, &next_filter)?;
971 }
972
973 if let Some(e) = &self.else_expr {
976 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
978 let else_value = expr.evaluate(&remainder_batch)?;
979 result_builder.add_branch_result(&remainder_rows, else_value)?;
980 }
981
982 result_builder.finish()
983 }
984
985 fn case_when_no_expr(
987 &self,
988 batch: &RecordBatch,
989 return_type: &DataType,
990 ) -> Result<ColumnarValue> {
991 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
992
993 let mut remainder_rows: ArrayRef =
995 Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
996 let mut remainder_batch = Cow::Borrowed(batch);
998
999 for i in 0..self.when_then_expr.len() {
1000 let when_predicate = &self.when_then_expr[i].0;
1003 let when_value = when_predicate
1004 .evaluate(&remainder_batch)?
1005 .into_array(remainder_batch.num_rows())?;
1006 let when_value = as_boolean_array(&when_value).map_err(|_| {
1007 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1008 })?;
1009
1010 let when_true_count = when_value.true_count();
1013
1014 if when_true_count == 0 {
1016 continue;
1017 }
1018
1019 if when_true_count == remainder_batch.num_rows() {
1021 let then_expression = &self.when_then_expr[i].1;
1022 let then_value = then_expression.evaluate(&remainder_batch)?;
1023 result_builder.add_branch_result(&remainder_rows, then_value)?;
1024 return result_builder.finish();
1025 }
1026
1027 let then_filter = create_filter(when_value, true);
1033 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
1034 let then_rows = filter_array(&remainder_rows, &then_filter)?;
1035
1036 let then_expression = &self.when_then_expr[i].1;
1037 let then_value = then_expression.evaluate(&then_batch)?;
1038 result_builder.add_branch_result(&then_rows, then_value)?;
1039
1040 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
1043 return result_builder.finish();
1044 }
1045
1046 let next_selection = match when_value.null_count() {
1048 0 => not(when_value),
1049 _ => {
1050 not(&prep_null_mask_filter(when_value))
1053 }
1054 }?;
1055 let next_filter = create_filter(&next_selection, true);
1056 remainder_batch =
1057 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
1058 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
1059 }
1060
1061 if let Some(e) = &self.else_expr {
1064 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
1066 let else_value = expr.evaluate(&remainder_batch)?;
1067 result_builder.add_branch_result(&remainder_rows, else_value)?;
1068 }
1069
1070 result_builder.finish()
1071 }
1072
1073 fn expr_or_expr(
1075 &self,
1076 batch: &RecordBatch,
1077 when_value: &BooleanArray,
1078 ) -> Result<ColumnarValue> {
1079 let when_value = match when_value.null_count() {
1080 0 => Cow::Borrowed(when_value),
1081 _ => {
1082 Cow::Owned(prep_null_mask_filter(when_value))
1084 }
1085 };
1086
1087 let optimize_filter = batch.num_columns() > 1
1088 || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
1089
1090 let when_filter = create_filter(&when_value, optimize_filter);
1091 let then_batch = filter_record_batch(batch, &when_filter)?;
1092 let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
1093
1094 let else_selection = not(&when_value)?;
1095 let else_filter = create_filter(&else_selection, optimize_filter);
1096 let else_batch = filter_record_batch(batch, &else_filter)?;
1097
1098 let e = self.else_expr.as_ref().unwrap();
1100 let return_type = self.data_type(&batch.schema())?;
1101 let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
1102 .unwrap_or_else(|_| Arc::clone(e));
1103
1104 let else_value = else_expr.evaluate(&else_batch)?;
1105
1106 Ok(ColumnarValue::Array(merge(
1107 &when_value,
1108 then_value,
1109 else_value,
1110 )?))
1111 }
1112}
1113
1114impl CaseExpr {
1115 fn case_when_with_expr(
1123 &self,
1124 batch: &RecordBatch,
1125 projected: &ProjectedCaseBody,
1126 ) -> Result<ColumnarValue> {
1127 let return_type = self.data_type(&batch.schema())?;
1128 if projected.projection.len() < batch.num_columns() {
1129 let projected_batch = batch.project(&projected.projection)?;
1130 projected
1131 .body
1132 .case_when_with_expr(&projected_batch, &return_type)
1133 } else {
1134 self.body.case_when_with_expr(batch, &return_type)
1135 }
1136 }
1137
1138 fn case_when_no_expr(
1146 &self,
1147 batch: &RecordBatch,
1148 projected: &ProjectedCaseBody,
1149 ) -> Result<ColumnarValue> {
1150 let return_type = self.data_type(&batch.schema())?;
1151 if projected.projection.len() < batch.num_columns() {
1152 let projected_batch = batch.project(&projected.projection)?;
1153 projected
1154 .body
1155 .case_when_no_expr(&projected_batch, &return_type)
1156 } else {
1157 self.body.case_when_no_expr(batch, &return_type)
1158 }
1159 }
1160
1161 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1171 let when_expr = &self.body.when_then_expr[0].0;
1172 let then_expr = &self.body.when_then_expr[0].1;
1173
1174 match when_expr.evaluate(batch)? {
1175 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1177 then_expr.evaluate(batch)
1178 }
1179 ColumnarValue::Scalar(_) => {
1181 ScalarValue::try_from(self.data_type(&batch.schema())?)
1183 .map(ColumnarValue::Scalar)
1184 }
1185 ColumnarValue::Array(bit_mask) => {
1187 let bit_mask = bit_mask
1188 .as_any()
1189 .downcast_ref::<BooleanArray>()
1190 .expect("predicate should evaluate to a boolean array");
1191 let bit_mask = match bit_mask.null_count() {
1193 0 => not(bit_mask)?,
1194 _ => not(&prep_null_mask_filter(bit_mask))?,
1195 };
1196 match then_expr.evaluate(batch)? {
1197 ColumnarValue::Array(array) => {
1198 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1199 }
1200 ColumnarValue::Scalar(_) => {
1201 internal_err!("expression did not evaluate to an array")
1202 }
1203 }
1204 }
1205 }
1206 }
1207
1208 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1209 let return_type = self.data_type(&batch.schema())?;
1210
1211 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1213 let when_value = when_value.into_array(batch.num_rows())?;
1214 let when_value = as_boolean_array(&when_value).map_err(|_| {
1215 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1216 })?;
1217
1218 let when_value = match when_value.null_count() {
1220 0 => Cow::Borrowed(when_value),
1221 _ => Cow::Owned(prep_null_mask_filter(when_value)),
1222 };
1223
1224 let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1226 let then_value = Scalar::new(then_value.into_array(1)?);
1227
1228 let Some(e) = &self.body.else_expr else {
1229 return internal_err!("expression did not evaluate to an array");
1230 };
1231 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1233 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1234 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1235 }
1236
1237 fn expr_or_expr(
1238 &self,
1239 batch: &RecordBatch,
1240 projected: &ProjectedCaseBody,
1241 ) -> Result<ColumnarValue> {
1242 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1244 let when_value = when_value.into_array(1)?;
1248 let when_value = as_boolean_array(&when_value).map_err(|e| {
1249 DataFusionError::Context(
1250 "WHEN expression did not return a BooleanArray".to_string(),
1251 Box::new(e),
1252 )
1253 })?;
1254
1255 let true_count = when_value.true_count();
1256 if true_count == when_value.len() {
1257 self.body.when_then_expr[0].1.evaluate(batch)
1259 } else if true_count == 0 {
1260 self.body.else_expr.as_ref().unwrap().evaluate(batch)
1262 } else if projected.projection.len() < batch.num_columns() {
1263 let projected_batch = batch.project(&projected.projection)?;
1266 projected.body.expr_or_expr(&projected_batch, when_value)
1267 } else {
1268 self.body.expr_or_expr(batch, when_value)
1270 }
1271 }
1272}
1273
1274impl PhysicalExpr for CaseExpr {
1275 fn as_any(&self) -> &dyn Any {
1277 self
1278 }
1279
1280 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1281 self.body.data_type(input_schema)
1282 }
1283
1284 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1285 let then_nullable = self
1287 .body
1288 .when_then_expr
1289 .iter()
1290 .map(|(_, t)| t.nullable(input_schema))
1291 .collect::<Result<Vec<_>>>()?;
1292 if then_nullable.contains(&true) {
1293 Ok(true)
1294 } else if let Some(e) = &self.body.else_expr {
1295 e.nullable(input_schema)
1296 } else {
1297 Ok(true)
1300 }
1301 }
1302
1303 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1304 match &self.eval_method {
1305 EvalMethod::WithExpression(p) => {
1306 self.case_when_with_expr(batch, p)
1309 }
1310 EvalMethod::NoExpression(p) => {
1311 self.case_when_no_expr(batch, p)
1314 }
1315 EvalMethod::InfallibleExprOrNull => {
1316 self.case_column_or_null(batch)
1318 }
1319 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1320 EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1321 }
1322 }
1323
1324 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1325 let mut children = vec![];
1326 if let Some(expr) = &self.body.expr {
1327 children.push(expr)
1328 }
1329 self.body.when_then_expr.iter().for_each(|(cond, value)| {
1330 children.push(cond);
1331 children.push(value);
1332 });
1333
1334 if let Some(else_expr) = &self.body.else_expr {
1335 children.push(else_expr)
1336 }
1337 children
1338 }
1339
1340 fn with_new_children(
1342 self: Arc<Self>,
1343 children: Vec<Arc<dyn PhysicalExpr>>,
1344 ) -> Result<Arc<dyn PhysicalExpr>> {
1345 if children.len() != self.children().len() {
1346 internal_err!("CaseExpr: Wrong number of children")
1347 } else {
1348 let (expr, when_then_expr, else_expr) =
1349 match (self.expr().is_some(), self.body.else_expr.is_some()) {
1350 (true, true) => (
1351 Some(&children[0]),
1352 &children[1..children.len() - 1],
1353 Some(&children[children.len() - 1]),
1354 ),
1355 (true, false) => {
1356 (Some(&children[0]), &children[1..children.len()], None)
1357 }
1358 (false, true) => (
1359 None,
1360 &children[0..children.len() - 1],
1361 Some(&children[children.len() - 1]),
1362 ),
1363 (false, false) => (None, &children[0..children.len()], None),
1364 };
1365 Ok(Arc::new(CaseExpr::try_new(
1366 expr.cloned(),
1367 when_then_expr.iter().cloned().tuples().collect(),
1368 else_expr.cloned(),
1369 )?))
1370 }
1371 }
1372
1373 fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1374 write!(f, "CASE ")?;
1375 if let Some(e) = &self.body.expr {
1376 e.fmt_sql(f)?;
1377 write!(f, " ")?;
1378 }
1379
1380 for (w, t) in &self.body.when_then_expr {
1381 write!(f, "WHEN ")?;
1382 w.fmt_sql(f)?;
1383 write!(f, " THEN ")?;
1384 t.fmt_sql(f)?;
1385 write!(f, " ")?;
1386 }
1387
1388 if let Some(e) = &self.body.else_expr {
1389 write!(f, "ELSE ")?;
1390 e.fmt_sql(f)?;
1391 write!(f, " ")?;
1392 }
1393 write!(f, "END")
1394 }
1395}
1396
1397pub fn case(
1399 expr: Option<Arc<dyn PhysicalExpr>>,
1400 when_thens: Vec<WhenThen>,
1401 else_expr: Option<Arc<dyn PhysicalExpr>>,
1402) -> Result<Arc<dyn PhysicalExpr>> {
1403 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1404}
1405
1406#[cfg(test)]
1407mod tests {
1408 use super::*;
1409
1410 use crate::expressions::{binary, cast, col, lit, BinaryExpr};
1411 use arrow::buffer::Buffer;
1412 use arrow::datatypes::DataType::Float64;
1413 use arrow::datatypes::Field;
1414 use datafusion_common::cast::{as_float64_array, as_int32_array};
1415 use datafusion_common::plan_err;
1416 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1417 use datafusion_expr::type_coercion::binary::comparison_coercion;
1418 use datafusion_expr::Operator;
1419 use datafusion_physical_expr_common::physical_expr::fmt_sql;
1420
1421 #[test]
1422 fn case_with_expr() -> Result<()> {
1423 let batch = case_test_batch()?;
1424 let schema = batch.schema();
1425
1426 let when1 = lit("foo");
1428 let then1 = lit(123i32);
1429 let when2 = lit("bar");
1430 let then2 = lit(456i32);
1431
1432 let expr = generate_case_when_with_type_coercion(
1433 Some(col("a", &schema)?),
1434 vec![(when1, then1), (when2, then2)],
1435 None,
1436 schema.as_ref(),
1437 )?;
1438 let result = expr
1439 .evaluate(&batch)?
1440 .into_array(batch.num_rows())
1441 .expect("Failed to convert to array");
1442 let result = as_int32_array(&result)?;
1443
1444 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1445
1446 assert_eq!(expected, result);
1447
1448 Ok(())
1449 }
1450
1451 #[test]
1452 fn case_with_expr_else() -> Result<()> {
1453 let batch = case_test_batch()?;
1454 let schema = batch.schema();
1455
1456 let when1 = lit("foo");
1458 let then1 = lit(123i32);
1459 let when2 = lit("bar");
1460 let then2 = lit(456i32);
1461 let else_value = lit(999i32);
1462
1463 let expr = generate_case_when_with_type_coercion(
1464 Some(col("a", &schema)?),
1465 vec![(when1, then1), (when2, then2)],
1466 Some(else_value),
1467 schema.as_ref(),
1468 )?;
1469 let result = expr
1470 .evaluate(&batch)?
1471 .into_array(batch.num_rows())
1472 .expect("Failed to convert to array");
1473 let result = as_int32_array(&result)?;
1474
1475 let expected =
1476 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1477
1478 assert_eq!(expected, result);
1479
1480 Ok(())
1481 }
1482
#[test]
fn case_with_expr_divide_by_zero() -> Result<()> {
    // CASE a WHEN 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    let zero = lit(0i32);
    let null_f64 = lit(ScalarValue::Float64(None));
    let quotient = binary(
        lit(25.0f64),
        Operator::Divide,
        cast(col("a", &schema)?, &schema, Float64)?,
        &schema,
    )?;

    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![(zero, null_f64)],
        Some(quotient),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // a = [1, 0, NULL, 5]: the zero row and the NULL row are expected to be NULL.
    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1517
#[test]
fn case_without_expr() -> Result<()> {
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // Builds the predicate `a = <value>` against the batch schema.
    let a_equals = |value: &'static str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", &schema)?, Operator::Eq, lit(value), &schema)
    };

    let branches = vec![
        (a_equals("foo")?, lit(123i32)),
        (a_equals("bar")?, lit(456i32)),
    ];
    let case_expr =
        generate_case_when_with_type_coercion(None, branches, None, schema.as_ref())?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    // Without an ELSE branch the unmatched rows are NULL.
    let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1557
#[test]
fn case_with_expr_when_null() -> Result<()> {
    // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let null_branch = (lit(ScalarValue::Utf8(None)), lit(0i32));
    let self_branch = (col("a", &schema)?, lit(123i32));
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![null_branch, self_branch],
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    // The `WHEN NULL` branch is never expected to fire: non-null rows match
    // `WHEN a`, and the NULL row falls through to ELSE.
    let expected = Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1589
#[test]
fn case_without_expr_divide_by_zero() -> Result<()> {
    // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    let is_positive = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    let quotient = binary(
        lit(25.0f64),
        Operator::Divide,
        cast(col("a", &schema)?, &schema, Float64)?,
        &schema,
    )?;
    let null_f64 = lit(ScalarValue::Float64(None));

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_positive, quotient)],
        Some(null_f64),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // a = [1, 0, NULL, 5]: only the strictly positive rows take the THEN branch.
    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1624
1625 fn case_test_batch1() -> Result<RecordBatch> {
1626 let schema = Schema::new(vec![
1627 Field::new("a", DataType::Int32, true),
1628 Field::new("b", DataType::Int32, true),
1629 Field::new("c", DataType::Int32, true),
1630 ]);
1631 let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1632 let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1633 let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1634 let batch = RecordBatch::try_new(
1635 Arc::new(schema),
1636 vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1637 )?;
1638 Ok(batch)
1639 }
1640
#[test]
fn case_without_expr_else() -> Result<()> {
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // Builds the predicate `a = <value>` against the batch schema.
    let a_equals = |value: &'static str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", &schema)?, Operator::Eq, lit(value), &schema)
    };

    let branches = vec![
        (a_equals("foo")?, lit(123i32)),
        (a_equals("bar")?, lit(456i32)),
    ];
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        None,
        branches,
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    // Unmatched rows, including the NULL row, take the ELSE value.
    let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1682
#[test]
fn case_with_type_cast() -> Result<()> {
    // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let float_then = lit(123.3f64);
    let int_else = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, float_then)],
        Some(int_else),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // The Int32 ELSE literal is expected to surface as Float64 in the result.
    let expected =
        Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1718
#[test]
fn case_with_matches_and_nulls() -> Result<()> {
    // CASE WHEN load4 = 1.77 THEN load4 END
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let matches_target = binary(
        col("load4", &schema)?,
        Operator::Eq,
        lit(1.77f64),
        &schema,
    )?;
    let passthrough = col("load4", &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(matches_target, passthrough)],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // Only the valid rows holding 1.77 (first and last) are expected to survive.
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1753
#[test]
fn case_with_scalar_predicate() -> Result<()> {
    // CASE WHEN TRUE THEN load4 END
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let always_true = lit(true);
    let passthrough = col("load4", &schema)?;
    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(always_true, passthrough)],
        None,
        schema.as_ref(),
    )?;

    // First pass: the multi-row batch should come back unchanged.
    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, Some(1.78), None, Some(1.77)]);
    assert_eq!(&expected, actual);

    // Second pass: the same expression over a single-row batch with the same schema.
    let single_row = Float64Array::from(vec![Some(1.1)]);
    let batch =
        RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(single_row.clone())])?;
    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");
    assert_eq!(&single_row, actual);

    Ok(())
}
1800
#[test]
fn case_expr_matches_and_nulls() -> Result<()> {
    // CASE load4 WHEN 1.77 THEN load4 END
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let base = col("load4", &schema)?;
    let branch = (lit(1.77f64), col("load4", &schema)?);

    let case_expr = generate_case_when_with_type_coercion(
        Some(base),
        vec![branch],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // Only the valid rows holding 1.77 (first and last) are expected to survive.
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1831
#[test]
fn test_when_null_and_some_cond_else_null() -> Result<()> {
    // CASE WHEN (NULL AND a = 'foo') THEN a END
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let null_bool: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(None)));
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let predicate = binary(null_bool, Operator::And, is_foo, &schema)?;
    let passthrough = col("a", &schema)?;

    let case_expr =
        Arc::new(CaseExpr::try_new(None, vec![(predicate, passthrough)], None)?);

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let strings = as_string_array(&evaluated);

    // Every row of the result is expected to be logically NULL.
    assert_eq!(strings.logical_null_count(), batch.num_rows());
    Ok(())
}
1857
1858 fn case_test_batch() -> Result<RecordBatch> {
1859 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1860 let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1861 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1862 Ok(batch)
1863 }
1864
1865 fn case_test_batch_nulls() -> Result<RecordBatch> {
1868 let load4: Float64Array = vec![
1869 Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
1876 .into_iter()
1877 .collect();
1878
1879 let null_buffer = Buffer::from([0b00101001u8]);
1880 let load4 = load4
1881 .into_data()
1882 .into_builder()
1883 .null_bit_buffer(Some(null_buffer))
1884 .build()
1885 .unwrap();
1886 let load4: Float64Array = load4.into();
1887
1888 let batch =
1889 RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1890 Ok(batch)
1891 }
1892
#[test]
fn case_test_incompatible() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // Builds the predicate `a = <value>` against the batch schema.
    let a_equals = |value: &'static str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", &schema)?, Operator::Eq, lit(value), &schema)
    };

    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
    // Int32 and Boolean THEN values: building the expression must fail.
    let mismatched = generate_case_when_with_type_coercion(
        None,
        vec![
            (a_equals("foo")?, lit(123i32)),
            (a_equals("bar")?, lit(true)),
        ],
        None,
        schema.as_ref(),
    );
    assert!(mismatched.is_err());

    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
    // Int32, Int64 and Float64 values: the result type must be Float64.
    let numeric = generate_case_when_with_type_coercion(
        None,
        vec![
            (a_equals("foo")?, lit(123i32)),
            (a_equals("bar")?, lit(456i64)),
        ],
        Some(lit(1.23f64)),
        schema.as_ref(),
    );
    assert!(numeric.is_ok());
    let result_type = numeric.unwrap().data_type(schema.as_ref())?;
    assert_eq!(Float64, result_type);
    Ok(())
}
1955
#[test]
fn case_eq() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let when1 = lit("foo");
    let then1 = lit(123i32);
    let when2 = lit("bar");
    let then2 = lit(456i32);
    let else_value = lit(999i32);

    // Each call yields fresh `Arc` handles to the same underlying branch expressions.
    let foo_branch = || (Arc::clone(&when1), Arc::clone(&then1));
    let bar_branch = || (Arc::clone(&when2), Arc::clone(&then2));

    // Builds `CASE a <branches> [ELSE otherwise] END` over `schema`.
    let build = |branches: Vec<WhenThen>,
                 otherwise: Option<Arc<dyn PhysicalExpr>>|
     -> Result<Arc<dyn PhysicalExpr>> {
        generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            branches,
            otherwise,
            &schema,
        )
    };

    // expr1 and expr2 are built from identical inputs.
    let expr1 = build(
        vec![foo_branch(), bar_branch()],
        Some(Arc::clone(&else_value)),
    )?;
    let expr2 = build(
        vec![foo_branch(), bar_branch()],
        Some(Arc::clone(&else_value)),
    )?;
    // expr3 drops the ELSE branch; expr4 drops the second WHEN/THEN pair.
    let expr3 = build(vec![foo_branch(), bar_branch()], None)?;
    let expr4 = build(vec![foo_branch()], Some(Arc::clone(&else_value)))?;

    assert!(expr1.eq(&expr2));
    assert!(expr2.eq(&expr1));

    assert!(expr2.ne(&expr3));
    assert!(expr3.ne(&expr2));

    assert!(expr1.ne(&expr4));
    assert!(expr4.ne(&expr1));

    Ok(())
}
2011
#[test]
fn case_transform() -> Result<()> {
    /// Replaces a non-null `Utf8` literal with its upper-cased form and
    /// reports every other node as unchanged.
    fn uppercase_utf8_literal(
        e: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        let replacement =
            e.as_any()
                .downcast_ref::<Literal>()
                .and_then(|literal| match literal.value() {
                    ScalarValue::Utf8(Some(str_value)) => {
                        Some(lit(str_value.to_uppercase()))
                    }
                    _ => None,
                });
        Ok(match replacement {
            Some(new_expr) => Transformed::yes(new_expr),
            None => Transformed::no(e),
        })
    }

    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let fallback = lit(999i32);

    let expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(fallback),
        &schema,
    )?;

    // Apply the same rewrite through both traversal entry points.
    let expr2 = Arc::clone(&expr)
        .transform(uppercase_utf8_literal)
        .data()
        .unwrap();
    let expr3 = Arc::clone(&expr)
        .transform_down(uppercase_utf8_literal)
        .data()
        .unwrap();

    // The rewrite must change the expression, and both traversals must agree.
    assert!(expr.ne(&expr2));
    assert!(expr2.eq(&expr3));

    Ok(())
}
2077
#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // Input: c1 = 0..1000 (no nulls), c2 = "string {i}" with every 7th row NULL.
    let mut c1 = Int32Builder::new();
    let mut c2 = StringBuilder::new();
    for i in 0..1000 {
        c1.append_value(i);
        match i % 7 {
            0 => c2.append_null(),
            _ => c2.append_value(format!("string {i}")),
        }
    }
    let c1 = Arc::new(c1.finish());
    let c2 = Arc::new(c2.finish());
    let fields = vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ];
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![c1, c2]).unwrap();

    // CASE WHEN c1 <= 250 THEN c2 END
    let predicate = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
    assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));

    let ColumnarValue::Array(array) = expr.evaluate(&batch)? else {
        unreachable!()
    };
    assert_eq!(1000, array.len());
    // 251 rows satisfy the predicate and 36 of those have a NULL c2,
    // leaving 215 non-null values and 785 nulls.
    assert_eq!(785, array.null_count());
    Ok(())
}
2116
#[test]
fn test_expr_or_expr_specialization() -> Result<()> {
    // CASE WHEN a <= 2 THEN b ELSE c END
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    let predicate = binary(col("a", &schema)?, Operator::LtEq, lit(2i32), &schema)?;
    let on_true = col("b", &schema)?;
    let on_false = col("c", &schema)?;

    let expr = CaseExpr::try_new(None, vec![(predicate, on_true)], Some(on_false))?;
    assert!(matches!(
        expr.eval_method,
        EvalMethod::ExpressionOrExpression(_)
    ));

    let evaluated = expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated).expect("failed to downcast to Int32Array");

    // Rows 0 and 1 take `b`; the NULL-predicate row and the false row take `c`.
    let expected = Int32Array::from(vec![Some(3), None, Some(777), None]);
    assert_eq!(&expected, actual);
    Ok(())
}
2145
2146 fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2147 Arc::new(Column::new(name, index))
2148 }
2149
2150 fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2151 Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2152 }
2153
2154 fn generate_case_when_with_type_coercion(
2155 expr: Option<Arc<dyn PhysicalExpr>>,
2156 when_thens: Vec<WhenThen>,
2157 else_expr: Option<Arc<dyn PhysicalExpr>>,
2158 input_schema: &Schema,
2159 ) -> Result<Arc<dyn PhysicalExpr>> {
2160 let coerce_type =
2161 get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2162 let (when_thens, else_expr) = match coerce_type {
2163 None => plan_err!(
2164 "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2165 ),
2166 Some(data_type) => {
2167 let left = when_thens
2169 .into_iter()
2170 .map(|(when, then)| {
2171 let then = try_cast(then, input_schema, data_type.clone())?;
2172 Ok((when, then))
2173 })
2174 .collect::<Result<Vec<_>>>()?;
2175 let right = match else_expr {
2176 None => None,
2177 Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2178 };
2179
2180 Ok((left, right))
2181 }
2182 }?;
2183 case(expr, when_thens, else_expr)
2184 }
2185
2186 fn get_case_common_type(
2187 when_thens: &[WhenThen],
2188 else_expr: Option<Arc<dyn PhysicalExpr>>,
2189 input_schema: &Schema,
2190 ) -> Option<DataType> {
2191 let thens_type = when_thens
2192 .iter()
2193 .map(|when_then| {
2194 let data_type = &when_then.1.data_type(input_schema).unwrap();
2195 data_type.clone()
2196 })
2197 .collect::<Vec<_>>();
2198 let else_type = match else_expr {
2199 None => {
2200 thens_type[0].clone()
2202 }
2203 Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2204 };
2205 thens_type
2206 .iter()
2207 .try_fold(else_type, |left_type, right_type| {
2208 comparison_coercion(&left_type, right_type)
2211 })
2212 }
2213
#[test]
fn test_fmt_sql() -> Result<()> {
    // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
    let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);

    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let branches = vec![(is_foo, lit(123.3f64))];
    let fallback = lit(999i32);

    let expr =
        generate_case_when_with_type_coercion(None, branches, Some(fallback), &schema)?;

    // `Display` output keeps the column index suffix (`a@0`).
    assert_eq!(
        expr.to_string(),
        "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
    );

    // `fmt_sql` output refers to the column by bare name.
    assert_eq!(
        fmt_sql(expr.as_ref()).to_string(),
        "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
    );

    Ok(())
}
2244
#[test]
fn test_merge_n() {
    let sources = [
        StringArray::from(vec![Some("A")]).to_data(),
        StringArray::from(vec![Some("B")]).to_data(),
        StringArray::from(vec![Some("C"), Some("D")]).to_data(),
    ];

    let indices = vec![
        PartialResultIndex::none(),
        PartialResultIndex::try_new(1).unwrap(),
        PartialResultIndex::try_new(0).unwrap(),
        PartialResultIndex::none(),
        PartialResultIndex::try_new(2).unwrap(),
        PartialResultIndex::try_new(2).unwrap(),
    ];

    let merged = merge_n(&sources, &indices).unwrap();
    let merged = merged.as_string::<i32>();

    assert_eq!(merged.len(), indices.len());

    // Expected output per row: `none()` slots are null, and repeated references
    // to the same source array yield its rows in order ("C" then "D").
    let expected = [None, Some("B"), Some("A"), None, Some("C"), Some("D")];
    for (row, want) in expected.iter().enumerate() {
        match want {
            Some(text) => {
                assert!(merged.is_valid(row));
                assert_eq!(merged.value(row), *text);
            }
            None => assert!(!merged.is_valid(row)),
        }
    }
}
2275
#[test]
fn test_merge() {
    let truthy = Arc::new(StringArray::from(vec![Some("A"), Some("C")]));
    let falsy = Arc::new(StringArray::from(vec![Some("B")]));

    let mask = BooleanArray::from(vec![true, false, true]);

    let merged = merge(
        &mask,
        ColumnarValue::Array(truthy),
        ColumnarValue::Array(falsy),
    )
    .unwrap();
    let merged = merged.as_string::<i32>();

    assert_eq!(merged.len(), mask.len());

    // The result is expected to interleave both inputs following the mask.
    for (row, want) in ["A", "B", "C"].into_iter().enumerate() {
        assert!(merged.is_valid(row));
        assert_eq!(merged.value(row), want);
    }
}
2295}