1mod literal_lookup_table;
19
20use super::{Column, Literal};
21use crate::PhysicalExpr;
22use crate::expressions::{lit, try_cast};
23use arrow::array::*;
24use arrow::compute::kernels::zip::zip;
25use arrow::compute::{
26 FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::{
32 DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err,
33 internal_datafusion_err, internal_err,
34};
35use datafusion_expr::ColumnarValue;
36use indexmap::{IndexMap, IndexSet};
37use std::borrow::Cow;
38use std::hash::Hash;
39use std::{any::Any, sync::Arc};
40
41use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
42use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
43use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
44use datafusion_physical_expr_common::datum::compare_with_eq;
45use itertools::Itertools;
46use std::fmt::{Debug, Formatter};
47
/// A `(WHEN, THEN)` expression pair of a `CASE` expression.
pub(super) type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
49
/// Strategy used to evaluate a [`CaseExpr`].
///
/// The strategy is selected once, by `CaseExpr::find_best_eval_method`, when
/// the expression is constructed.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// `CASE WHEN condition THEN result [WHEN ...] [ELSE result] END`
    /// (no base expression), used when none of the specialized variants below
    /// apply.
    NoExpression(ProjectedCaseBody),
    /// `CASE expression WHEN value THEN result [WHEN ...] [ELSE result] END`,
    /// used when no literal lookup table could be built.
    WithExpression(ProjectedCaseBody),
    /// `CASE WHEN condition THEN column END`: a single branch whose `THEN` is a
    /// plain column reference and which has no `ELSE`. Rows where the
    /// condition is not true are nulled out.
    InfallibleExprOrNull,
    /// `CASE WHEN condition THEN literal ELSE literal END`.
    ScalarOrScalar,
    /// `CASE WHEN condition THEN expression ELSE expression END`: a single
    /// branch with an `ELSE`; each side is evaluated only on its own rows.
    ExpressionOrExpression(ProjectedCaseBody),

    /// Base expression present and `LiteralLookupTable::maybe_new` produced a
    /// table that maps base values directly to results.
    WithExprScalarLookupTable(LiteralLookupTable),
}
85
// `EvalMethod` derives `Hash`/`PartialEq`/`Eq`, so its `LiteralLookupTable`
// payload must implement them as well. The table is built from the same
// `CaseBody` that `CaseExpr` already hashes and compares (see
// `CaseExpr::find_best_eval_method`). Presumably that is why the table is
// treated as carrying no identity of its own. NOTE(review): confirm this
// holds for every way a `LiteralLookupTable` can be created.
impl Hash for LiteralLookupTable {
    /// Intentionally hashes nothing; consistent with `eq` always returning
    /// `true`.
    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
}

impl PartialEq for LiteralLookupTable {
    /// All lookup tables compare equal; see the note on the `Hash` impl.
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

impl Eq for LiteralLookupTable {}
109
/// The parts of a `CASE` expression as written: the optional base expression,
/// the `WHEN`/`THEN` pairs in order, and the optional `ELSE`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression, compared for equality with each `WHEN` value.
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// `(WHEN, THEN)` pairs, tried in order.
    when_then_expr: Vec<WhenThen>,
    /// Optional `ELSE` expression. `CaseExpr::try_new` drops a literal `NULL`
    /// here.
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
121
122impl CaseBody {
123 fn project(&self) -> Result<ProjectedCaseBody> {
125 let mut used_column_indices = IndexSet::<usize>::new();
127 let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
128 expr.apply(|expr| {
129 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
130 used_column_indices.insert(column.index());
131 }
132 Ok(TreeNodeRecursion::Continue)
133 })
134 .expect("Closure cannot fail");
135 };
136
137 if let Some(e) = &self.expr {
138 collect_column_indices(e);
139 }
140 self.when_then_expr.iter().for_each(|(w, t)| {
141 collect_column_indices(w);
142 collect_column_indices(t);
143 });
144 if let Some(e) = &self.else_expr {
145 collect_column_indices(e);
146 }
147
148 let column_index_map = used_column_indices
150 .iter()
151 .enumerate()
152 .map(|(projected, original)| (*original, projected))
153 .collect::<IndexMap<usize, usize>>();
154
155 let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
158 Arc::clone(expr)
159 .transform_down(|e| {
160 if let Some(column) = e.as_any().downcast_ref::<Column>() {
161 let original = column.index();
162 let projected = *column_index_map.get(&original).unwrap();
163 if projected != original {
164 return Ok(Transformed::yes(Arc::new(Column::new(
165 column.name(),
166 projected,
167 ))));
168 }
169 }
170 Ok(Transformed::no(e))
171 })
172 .map(|t| t.data)
173 };
174
175 let projected_body = CaseBody {
176 expr: self.expr.as_ref().map(project).transpose()?,
177 when_then_expr: self
178 .when_then_expr
179 .iter()
180 .map(|(e, t)| Ok((project(e)?, project(t)?)))
181 .collect::<Result<Vec<_>>>()?,
182 else_expr: self.else_expr.as_ref().map(project).transpose()?,
183 };
184
185 let projection = column_index_map
187 .iter()
188 .sorted_by_key(|(_, v)| **v)
189 .map(|(k, _)| *k)
190 .collect::<Vec<_>>();
191
192 Ok(ProjectedCaseBody {
193 projection,
194 body: projected_body,
195 })
196 }
197}
198
/// A [`CaseBody`] rewritten by `CaseBody::project` so that it references only
/// the columns it uses.
///
/// `body` must be evaluated against a batch whose columns are the original
/// input columns listed in `projection`, in that order. The `CaseExpr`
/// evaluation methods build such a batch with `RecordBatch::project`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// Original column index for each column of the projected batch.
    projection: Vec<usize>,
    /// Case body whose `Column` indices refer to the projected batch.
    body: CaseBody,
}
231
/// Physical `CASE` expression.
///
/// Two forms are supported:
/// - `CASE WHEN cond THEN result [WHEN ...] [ELSE result] END`
/// - `CASE expr WHEN value THEN result [WHEN ...] [ELSE result] END`, where
///   `expr` is compared for equality with each `value`.
///
/// Branches are tried in order. Rows that match no branch take the `ELSE`
/// result, or null when there is no `ELSE`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// The `CASE` expression as written.
    body: CaseBody,
    /// Evaluation strategy, chosen once in `find_best_eval_method`.
    eval_method: EvalMethod,
}
256
257impl std::fmt::Display for CaseExpr {
258 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
259 write!(f, "CASE ")?;
260 if let Some(e) = &self.body.expr {
261 write!(f, "{e} ")?;
262 }
263 for (w, t) in &self.body.when_then_expr {
264 write!(f, "WHEN {w} THEN {t} ")?;
265 }
266 if let Some(e) = &self.body.else_expr {
267 write!(f, "ELSE {e} ")?;
268 }
269 write!(f, "END")
270 }
271}
272
273fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
279 expr.as_any().is::<Column>()
280}
281
282fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
284 let mut filter_builder = FilterBuilder::new(predicate);
285 if optimize {
286 filter_builder = filter_builder.optimize();
288 }
289 filter_builder.build()
290}
291
292fn multiple_arrays(data_type: &DataType) -> bool {
293 match data_type {
294 DataType::Struct(fields) => {
295 fields.len() > 1
296 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
297 }
298 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
299 _ => false,
300 }
301}
302
303fn filter_record_batch(
306 record_batch: &RecordBatch,
307 filter: &FilterPredicate,
308) -> std::result::Result<RecordBatch, ArrowError> {
309 let filtered_columns = record_batch
310 .columns()
311 .iter()
312 .map(|a| filter_array(a, filter))
313 .collect::<std::result::Result<Vec<_>, _>>()?;
314 unsafe {
320 Ok(RecordBatch::new_unchecked(
321 record_batch.schema(),
322 filtered_columns,
323 filter.count(),
324 ))
325 }
326}
327
/// Filters a single array with a pre-built [`FilterPredicate`].
///
/// Thin wrapper around `FilterPredicate::filter`.
#[inline(always)]
fn filter_array(
    array: &dyn Array,
    filter: &FilterPredicate,
) -> std::result::Result<ArrayRef, ArrowError> {
    filter.filter(array)
}
339
/// For one output row, the index of the partial result array that holds its
/// value.
///
/// The reserved value [`NONE_VALUE`] means that no branch produced a value for
/// the row.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    index: u32,
}

/// Sentinel `index` value meaning "no partial result for this row".
const NONE_VALUE: u32 = u32::MAX;
350
351impl PartialResultIndex {
352 fn none() -> Self {
354 Self { index: NONE_VALUE }
355 }
356
357 fn zero() -> Self {
358 Self { index: 0 }
359 }
360
361 fn try_new(index: usize) -> Result<Self> {
366 let Ok(index) = u32::try_from(index) else {
367 return internal_err!("Partial result index exceeds limit");
368 };
369
370 assert_or_internal_err!(
371 index != NONE_VALUE,
372 "Partial result index exceeds limit"
373 );
374
375 Ok(Self { index })
376 }
377
378 fn is_none(&self) -> bool {
380 self.index == NONE_VALUE
381 }
382}
383
384impl MergeIndex for PartialResultIndex {
385 fn index(&self) -> Option<usize> {
387 if self.is_none() {
388 None
389 } else {
390 Some(self.index as usize)
391 }
392 }
393}
394
395impl Debug for PartialResultIndex {
396 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
397 if self.is_none() {
398 write!(f, "null")
399 } else {
400 write!(f, "{}", self.index)
401 }
402 }
403}
404
/// Values accumulated so far by a [`ResultBuilder`].
enum ResultState {
    /// No branch has produced any values yet.
    Empty,
    /// Some rows have values, spread over one or more arrays.
    Partial {
        /// Value arrays, one per call to `ResultBuilder::add_partial_result`.
        arrays: Vec<ArrayRef>,
        /// One entry per output row (`row_count` long). Each entry is the
        /// position in `arrays` holding that row's value, or
        /// `PartialResultIndex::none()` if the row has no value yet.
        indices: Vec<PartialResultIndex>,
    },
    /// A single value covers all rows.
    Complete(ColumnarValue),
}
420
/// Assembles the final result of a `CASE` evaluation from per-branch results.
///
/// Each branch contributes values for a subset of the input rows, identified
/// by their row indices. Rows that never receive a value are null in the
/// output of `finish`.
struct ResultBuilder {
    /// Data type of the result; used for the all-null result in `finish`.
    data_type: DataType,
    /// Number of rows in the input batch, and therefore in the result.
    row_count: usize,
    /// Values accumulated so far.
    state: ResultState,
}
436
437impl ResultBuilder {
438 fn new(data_type: &DataType, row_count: usize) -> Self {
442 Self {
443 data_type: data_type.clone(),
444 row_count,
445 state: ResultState::Empty,
446 }
447 }
448
449 fn add_branch_result(
482 &mut self,
483 row_indices: &ArrayRef,
484 value: ColumnarValue,
485 ) -> Result<()> {
486 match value {
487 ColumnarValue::Array(a) => {
488 if a.len() != row_indices.len() {
489 internal_err!("Array length must match row indices length")
490 } else if row_indices.len() == self.row_count {
491 self.set_complete_result(ColumnarValue::Array(a))
492 } else {
493 self.add_partial_result(row_indices, a)
494 }
495 }
496 ColumnarValue::Scalar(s) => {
497 if row_indices.len() == self.row_count {
498 self.set_complete_result(ColumnarValue::Scalar(s))
499 } else {
500 self.add_partial_result(
501 row_indices,
502 s.to_array_of_size(row_indices.len())?,
503 )
504 }
505 }
506 }
507 }
508
/// Stores `row_values` as a new partial result array and points each row in
/// `row_indices` at it.
///
/// `row_indices` must be a `UInt32` array without nulls. `row_values` holds one
/// value per entry of `row_indices`, in the same order.
///
/// # Errors
/// Returns an internal error if:
/// - `row_indices` contains nulls;
/// - a complete result was already set;
/// - the number of partial arrays exceeds the `PartialResultIndex` limit;
/// - (debug builds only) a row already has a value.
fn add_partial_result(
    &mut self,
    row_indices: &ArrayRef,
    row_values: ArrayRef,
) -> Result<()> {
    assert_or_internal_err!(
        row_indices.null_count() == 0,
        "Row indices must not contain nulls"
    );

    match &mut self.state {
        ResultState::Empty => {
            // First partial result: allocate the per-row index vector with
            // every row initially unassigned, then claim this branch's rows.
            let array_index = PartialResultIndex::zero();
            let mut indices = vec![PartialResultIndex::none(); self.row_count];
            for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
                indices[*row_ix as usize] = array_index;
            }

            self.state = ResultState::Partial {
                arrays: vec![row_values],
                indices,
            };

            Ok(())
        }
        ResultState::Partial { arrays, indices } => {
            // The new array's position in `arrays` becomes its index.
            let array_index = PartialResultIndex::try_new(arrays.len())?;

            arrays.push(row_values);

            for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
                // Branches are expected to cover disjoint rows; this is
                // verified only in debug builds.
                #[cfg(debug_assertions)]
                assert_or_internal_err!(
                    indices[*row_ix as usize].is_none(),
                    "Duplicate value for row {}",
                    *row_ix
                );

                indices[*row_ix as usize] = array_index;
            }
            Ok(())
        }
        ResultState::Complete(_) => internal_err!(
            "Cannot add a partial result when complete result is already set"
        ),
    }
}
564
565 fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
571 match &self.state {
572 ResultState::Empty => {
573 self.state = ResultState::Complete(value);
574 Ok(())
575 }
576 ResultState::Partial { .. } => {
577 internal_err!(
578 "Cannot set a complete result when there are already partial results"
579 )
580 }
581 ResultState::Complete(_) => internal_err!("Complete result already set"),
582 }
583 }
584
585 fn finish(self) -> Result<ColumnarValue> {
587 match self.state {
588 ResultState::Empty => {
589 Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
593 &self.data_type,
594 )?))
595 }
596 ResultState::Partial { arrays, indices } => {
597 let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
599 Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
600 }
601 ResultState::Complete(v) => {
602 Ok(v)
604 }
605 }
606 }
607}
608
609impl CaseExpr {
610 pub fn try_new(
612 expr: Option<Arc<dyn PhysicalExpr>>,
613 when_then_expr: Vec<WhenThen>,
614 else_expr: Option<Arc<dyn PhysicalExpr>>,
615 ) -> Result<Self> {
616 let else_expr = match &else_expr {
619 Some(e) => match e.as_any().downcast_ref::<Literal>() {
620 Some(lit) if lit.value().is_null() => None,
621 _ => else_expr,
622 },
623 _ => else_expr,
624 };
625
626 if when_then_expr.is_empty() {
627 return exec_err!("There must be at least one WHEN clause");
628 }
629
630 let body = CaseBody {
631 expr,
632 when_then_expr,
633 else_expr,
634 };
635
636 let eval_method = Self::find_best_eval_method(&body)?;
637
638 Ok(Self { body, eval_method })
639 }
640
641 fn find_best_eval_method(body: &CaseBody) -> Result<EvalMethod> {
642 if body.expr.is_some() {
643 if let Some(mapping) = LiteralLookupTable::maybe_new(body) {
644 return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
645 }
646
647 return Ok(EvalMethod::WithExpression(body.project()?));
648 }
649
650 Ok(
651 if body.when_then_expr.len() == 1
652 && is_cheap_and_infallible(&(body.when_then_expr[0].1))
653 && body.else_expr.is_none()
654 {
655 EvalMethod::InfallibleExprOrNull
656 } else if body.when_then_expr.len() == 1
657 && body.when_then_expr[0].1.as_any().is::<Literal>()
658 && body.else_expr.is_some()
659 && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
660 {
661 EvalMethod::ScalarOrScalar
662 } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
663 EvalMethod::ExpressionOrExpression(body.project()?)
664 } else {
665 EvalMethod::NoExpression(body.project()?)
666 },
667 )
668 }
669
/// Returns the optional base expression (`CASE <expr> WHEN ...`).
pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
    self.body.expr.as_ref()
}

/// Returns the `(WHEN, THEN)` pairs in evaluation order.
pub fn when_then_expr(&self) -> &[WhenThen] {
    &self.body.when_then_expr
}

/// Returns the optional `ELSE` expression. A literal `ELSE NULL` passed to
/// `try_new` is not retained, so it is reported as `None`.
pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
    self.body.else_expr.as_ref()
}
684}
685
686impl CaseBody {
687 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
688 let mut data_type = DataType::Null;
691 for i in 0..self.when_then_expr.len() {
692 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
693 if !data_type.equals_datatype(&DataType::Null) {
694 break;
695 }
696 }
697 if data_type.equals_datatype(&DataType::Null)
699 && let Some(e) = &self.else_expr
700 {
701 data_type = e.data_type(input_schema)?;
702 }
703
704 Ok(data_type)
705 }
706
707 fn case_when_with_expr(
709 &self,
710 batch: &RecordBatch,
711 return_type: &DataType,
712 ) -> Result<ColumnarValue> {
713 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
714
715 let mut remainder_rows: ArrayRef =
717 Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
718 let mut remainder_batch = Cow::Borrowed(batch);
720
721 let mut base_values = self
723 .expr
724 .as_ref()
725 .unwrap()
726 .evaluate(batch)?
727 .into_array(batch.num_rows())?;
728
729 let base_null_count = base_values.logical_null_count();
734 if base_null_count > 0 {
735 let base_not_nulls = is_not_null(base_values.as_ref())?;
739 let base_all_null = base_null_count == remainder_batch.num_rows();
740
741 if let Some(e) = &self.else_expr {
744 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
745
746 if base_all_null {
747 let nulls_value = expr.evaluate(&remainder_batch)?;
749 result_builder.add_branch_result(&remainder_rows, nulls_value)?;
750 } else {
751 let nulls_filter = create_filter(¬(&base_not_nulls)?, true);
753 let nulls_batch =
754 filter_record_batch(&remainder_batch, &nulls_filter)?;
755 let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
756 let nulls_value = expr.evaluate(&nulls_batch)?;
757 result_builder.add_branch_result(&nulls_rows, nulls_value)?;
758 }
759 }
760
761 if base_all_null {
763 return result_builder.finish();
764 }
765
766 let not_null_filter = create_filter(&base_not_nulls, true);
768 remainder_batch =
769 Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?);
770 remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?;
771 base_values = filter_array(&base_values, ¬_null_filter)?;
772 }
773
774 let base_value_is_nested = base_values.data_type().is_nested();
777
778 for i in 0..self.when_then_expr.len() {
779 let when_expr = &self.when_then_expr[i].0;
782 let when_value = match when_expr.evaluate(&remainder_batch)? {
783 ColumnarValue::Array(a) => {
784 compare_with_eq(&a, &base_values, base_value_is_nested)
785 }
786 ColumnarValue::Scalar(s) => {
787 compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
788 }
789 }?;
790
791 let when_true_count = when_value.true_count();
794
795 if when_true_count == 0 {
797 continue;
798 }
799
800 if when_true_count == remainder_batch.num_rows() {
802 let then_expression = &self.when_then_expr[i].1;
803 let then_value = then_expression.evaluate(&remainder_batch)?;
804 result_builder.add_branch_result(&remainder_rows, then_value)?;
805 return result_builder.finish();
806 }
807
808 let then_filter = create_filter(&when_value, true);
814 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
815 let then_rows = filter_array(&remainder_rows, &then_filter)?;
816
817 let then_expression = &self.when_then_expr[i].1;
818 let then_value = then_expression.evaluate(&then_batch)?;
819 result_builder.add_branch_result(&then_rows, then_value)?;
820
821 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
824 return result_builder.finish();
825 }
826
827 let next_selection = match when_value.null_count() {
829 0 => not(&when_value),
830 _ => {
831 not(&prep_null_mask_filter(&when_value))
834 }
835 }?;
836 let next_filter = create_filter(&next_selection, true);
837 remainder_batch =
838 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
839 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
840 base_values = filter_array(&base_values, &next_filter)?;
841 }
842
843 if let Some(e) = &self.else_expr {
846 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
848 let else_value = expr.evaluate(&remainder_batch)?;
849 result_builder.add_branch_result(&remainder_rows, else_value)?;
850 }
851
852 result_builder.finish()
853 }
854
855 fn case_when_no_expr(
857 &self,
858 batch: &RecordBatch,
859 return_type: &DataType,
860 ) -> Result<ColumnarValue> {
861 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
862
863 let mut remainder_rows: ArrayRef =
865 Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
866 let mut remainder_batch = Cow::Borrowed(batch);
868
869 for i in 0..self.when_then_expr.len() {
870 let when_predicate = &self.when_then_expr[i].0;
873 let when_value = when_predicate
874 .evaluate(&remainder_batch)?
875 .into_array(remainder_batch.num_rows())?;
876 let when_value = as_boolean_array(&when_value).map_err(|_| {
877 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
878 })?;
879
880 let when_true_count = when_value.true_count();
883
884 if when_true_count == 0 {
886 continue;
887 }
888
889 if when_true_count == remainder_batch.num_rows() {
891 let then_expression = &self.when_then_expr[i].1;
892 let then_value = then_expression.evaluate(&remainder_batch)?;
893 result_builder.add_branch_result(&remainder_rows, then_value)?;
894 return result_builder.finish();
895 }
896
897 let then_filter = create_filter(when_value, true);
903 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
904 let then_rows = filter_array(&remainder_rows, &then_filter)?;
905
906 let then_expression = &self.when_then_expr[i].1;
907 let then_value = then_expression.evaluate(&then_batch)?;
908 result_builder.add_branch_result(&then_rows, then_value)?;
909
910 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
913 return result_builder.finish();
914 }
915
916 let next_selection = match when_value.null_count() {
918 0 => not(when_value),
919 _ => {
920 not(&prep_null_mask_filter(when_value))
923 }
924 }?;
925 let next_filter = create_filter(&next_selection, true);
926 remainder_batch =
927 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
928 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
929 }
930
931 if let Some(e) = &self.else_expr {
934 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
936 let else_value = expr.evaluate(&remainder_batch)?;
937 result_builder.add_branch_result(&remainder_rows, else_value)?;
938 }
939
940 result_builder.finish()
941 }
942
943 fn expr_or_expr(
945 &self,
946 batch: &RecordBatch,
947 when_value: &BooleanArray,
948 ) -> Result<ColumnarValue> {
949 let when_value = match when_value.null_count() {
950 0 => Cow::Borrowed(when_value),
951 _ => {
952 Cow::Owned(prep_null_mask_filter(when_value))
954 }
955 };
956
957 let optimize_filter = batch.num_columns() > 1
958 || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
959
960 let when_filter = create_filter(&when_value, optimize_filter);
961 let then_batch = filter_record_batch(batch, &when_filter)?;
962 let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
963
964 let else_selection = not(&when_value)?;
965 let else_filter = create_filter(&else_selection, optimize_filter);
966 let else_batch = filter_record_batch(batch, &else_filter)?;
967
968 let e = self.else_expr.as_ref().unwrap();
970 let return_type = self.data_type(&batch.schema())?;
971 let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
972 .unwrap_or_else(|_| Arc::clone(e));
973
974 let else_value = else_expr.evaluate(&else_batch)?;
975
976 Ok(ColumnarValue::Array(match (then_value, else_value) {
977 (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
978 merge(&when_value, &t, &e)
979 }
980 (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
981 merge(&when_value, &t.to_scalar()?, &e)
982 }
983 (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
984 merge(&when_value, &t, &e.to_scalar()?)
985 }
986 (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
987 merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
988 }
989 }?))
990 }
991}
992
993impl CaseExpr {
994 fn case_when_with_expr(
1002 &self,
1003 batch: &RecordBatch,
1004 projected: &ProjectedCaseBody,
1005 ) -> Result<ColumnarValue> {
1006 let return_type = self.data_type(&batch.schema())?;
1007 if projected.projection.len() < batch.num_columns() {
1008 let projected_batch = batch.project(&projected.projection)?;
1009 projected
1010 .body
1011 .case_when_with_expr(&projected_batch, &return_type)
1012 } else {
1013 self.body.case_when_with_expr(batch, &return_type)
1014 }
1015 }
1016
1017 fn case_when_no_expr(
1025 &self,
1026 batch: &RecordBatch,
1027 projected: &ProjectedCaseBody,
1028 ) -> Result<ColumnarValue> {
1029 let return_type = self.data_type(&batch.schema())?;
1030 if projected.projection.len() < batch.num_columns() {
1031 let projected_batch = batch.project(&projected.projection)?;
1032 projected
1033 .body
1034 .case_when_no_expr(&projected_batch, &return_type)
1035 } else {
1036 self.body.case_when_no_expr(batch, &return_type)
1037 }
1038 }
1039
1040 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1050 let when_expr = &self.body.when_then_expr[0].0;
1051 let then_expr = &self.body.when_then_expr[0].1;
1052
1053 match when_expr.evaluate(batch)? {
1054 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1056 then_expr.evaluate(batch)
1057 }
1058 ColumnarValue::Scalar(_) => {
1060 ScalarValue::try_from(self.data_type(&batch.schema())?)
1062 .map(ColumnarValue::Scalar)
1063 }
1064 ColumnarValue::Array(bit_mask) => {
1066 let bit_mask = bit_mask
1067 .as_any()
1068 .downcast_ref::<BooleanArray>()
1069 .expect("predicate should evaluate to a boolean array");
1070 let bit_mask = match bit_mask.null_count() {
1072 0 => not(bit_mask)?,
1073 _ => not(&prep_null_mask_filter(bit_mask))?,
1074 };
1075 match then_expr.evaluate(batch)? {
1076 ColumnarValue::Array(array) => {
1077 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1078 }
1079 ColumnarValue::Scalar(_) => {
1080 internal_err!("expression did not evaluate to an array")
1081 }
1082 }
1083 }
1084 }
1085 }
1086
1087 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1088 let return_type = self.data_type(&batch.schema())?;
1089
1090 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1092 let when_value = when_value.into_array(batch.num_rows())?;
1093 let when_value = as_boolean_array(&when_value).map_err(|_| {
1094 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1095 })?;
1096
1097 let when_value = match when_value.null_count() {
1099 0 => Cow::Borrowed(when_value),
1100 _ => Cow::Owned(prep_null_mask_filter(when_value)),
1101 };
1102
1103 let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1105 let then_value = Scalar::new(then_value.into_array(1)?);
1106
1107 let Some(e) = &self.body.else_expr else {
1108 return internal_err!("expression did not evaluate to an array");
1109 };
1110 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1112 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1113 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1114 }
1115
1116 fn expr_or_expr(
1117 &self,
1118 batch: &RecordBatch,
1119 projected: &ProjectedCaseBody,
1120 ) -> Result<ColumnarValue> {
1121 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1123 let when_value = when_value.into_array(1)?;
1127 let when_value = as_boolean_array(&when_value).map_err(|e| {
1128 DataFusionError::Context(
1129 "WHEN expression did not return a BooleanArray".to_string(),
1130 Box::new(e),
1131 )
1132 })?;
1133
1134 let true_count = when_value.true_count();
1135 if true_count == when_value.len() {
1136 self.body.when_then_expr[0].1.evaluate(batch)
1138 } else if true_count == 0 {
1139 self.body.else_expr.as_ref().unwrap().evaluate(batch)
1141 } else if projected.projection.len() < batch.num_columns() {
1142 let projected_batch = batch.project(&projected.projection)?;
1145 projected.body.expr_or_expr(&projected_batch, when_value)
1146 } else {
1147 self.body.expr_or_expr(batch, when_value)
1149 }
1150 }
1151
1152 fn with_lookup_table(
1153 &self,
1154 batch: &RecordBatch,
1155 lookup_table: &LiteralLookupTable,
1156 ) -> Result<ColumnarValue> {
1157 let expr = self.body.expr.as_ref().unwrap();
1158 let evaluated_expression = expr.evaluate(batch)?;
1159
1160 let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_));
1161 let evaluated_expression = evaluated_expression.to_array(1)?;
1162
1163 let values = lookup_table.map_keys_to_values(&evaluated_expression)?;
1164
1165 let result = if is_scalar {
1166 ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?)
1167 } else {
1168 ColumnarValue::Array(values)
1169 };
1170
1171 Ok(result)
1172 }
1173}
1174
1175impl PhysicalExpr for CaseExpr {
/// Returns `self` as [`Any`] so callers can downcast to `CaseExpr`.
fn as_any(&self) -> &dyn Any {
    self
}

/// Delegates to `CaseBody::data_type`: the type of the first `THEN` that is
/// not `Null`-typed, falling back to the `ELSE` type.
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
    self.body.data_type(input_schema)
}
1184
/// Returns whether this expression can produce null for `input_schema`.
///
/// The result is nullable if some branch with a nullable `THEN` may be taken.
/// Otherwise it is nullable if the `ELSE` is nullable or missing.
///
/// Without a base expression, a nullable `THEN` `t` is ignored when its
/// condition `w` cannot be true whenever `t` is null. To check this, `t` is
/// replaced by a null literal inside `w`, and the result is evaluated without
/// any input (see `evaluate_predicate`).
///
/// With a base expression no such analysis is done: any nullable `THEN` makes
/// the result nullable.
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
    let nullable_then = self
        .body
        .when_then_expr
        .iter()
        .filter_map(|(w, t)| {
            let is_nullable = match t.nullable(input_schema) {
                // Propagate the error
                Err(e) => return Some(Err(e)),
                Ok(n) => n,
            };

            // A non-nullable THEN cannot contribute a null.
            if !is_nullable {
                return None;
            }

            // With a base expression, `w` is compared against it rather than
            // analysed, so assume this branch can be taken.
            if self.body.expr.is_some() {
                return Some(Ok(()));
            }

            // Substitute a null literal for `t` wherever it appears in `w`.
            let with_null = match replace_with_null(w, t.as_ref(), input_schema) {
                Err(e) => return Some(Err(e)),
                Ok(e) => e,
            };

            // `None` means the predicate could not be evaluated without input.
            let predicate_result = match evaluate_predicate(&with_null) {
                Err(e) => return Some(Err(e)),
                Ok(b) => b,
            };

            match predicate_result {
                // Unknown, or the branch may be taken while `t` is null.
                None | Some(true) => Some(Ok(())),
                // The branch is not taken when `t` is null.
                Some(false) => None,
            }
        })
        .next();

    if let Some(nullable_then) = nullable_then {
        // Some THEN may yield null (or an error occurred while checking).
        nullable_then.map(|_| true)
    } else if let Some(e) = &self.body.else_expr {
        // No THEN yields null, so the ELSE decides.
        e.nullable(input_schema)
    } else {
        // Without an ELSE, unmatched rows are null.
        Ok(true)
    }
}
1249
1250 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1251 match &self.eval_method {
1252 EvalMethod::WithExpression(p) => {
1253 self.case_when_with_expr(batch, p)
1256 }
1257 EvalMethod::NoExpression(p) => {
1258 self.case_when_no_expr(batch, p)
1261 }
1262 EvalMethod::InfallibleExprOrNull => {
1263 self.case_column_or_null(batch)
1265 }
1266 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1267 EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1268 EvalMethod::WithExprScalarLookupTable(lookup_table) => {
1269 self.with_lookup_table(batch, lookup_table)
1270 }
1271 }
1272 }
1273
1274 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1275 let mut children = vec![];
1276 if let Some(expr) = &self.body.expr {
1277 children.push(expr)
1278 }
1279 self.body.when_then_expr.iter().for_each(|(cond, value)| {
1280 children.push(cond);
1281 children.push(value);
1282 });
1283
1284 if let Some(else_expr) = &self.body.else_expr {
1285 children.push(else_expr)
1286 }
1287 children
1288 }
1289
1290 fn with_new_children(
1292 self: Arc<Self>,
1293 children: Vec<Arc<dyn PhysicalExpr>>,
1294 ) -> Result<Arc<dyn PhysicalExpr>> {
1295 if children.len() != self.children().len() {
1296 internal_err!("CaseExpr: Wrong number of children")
1297 } else {
1298 let (expr, when_then_expr, else_expr) =
1299 match (self.expr().is_some(), self.body.else_expr.is_some()) {
1300 (true, true) => (
1301 Some(&children[0]),
1302 &children[1..children.len() - 1],
1303 Some(&children[children.len() - 1]),
1304 ),
1305 (true, false) => {
1306 (Some(&children[0]), &children[1..children.len()], None)
1307 }
1308 (false, true) => (
1309 None,
1310 &children[0..children.len() - 1],
1311 Some(&children[children.len() - 1]),
1312 ),
1313 (false, false) => (None, &children[0..children.len()], None),
1314 };
1315 Ok(Arc::new(CaseExpr::try_new(
1316 expr.cloned(),
1317 when_then_expr.iter().cloned().tuples().collect(),
1318 else_expr.cloned(),
1319 )?))
1320 }
1321 }
1322
1323 fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1324 write!(f, "CASE ")?;
1325 if let Some(e) = &self.body.expr {
1326 e.fmt_sql(f)?;
1327 write!(f, " ")?;
1328 }
1329
1330 for (w, t) in &self.body.when_then_expr {
1331 write!(f, "WHEN ")?;
1332 w.fmt_sql(f)?;
1333 write!(f, " THEN ")?;
1334 t.fmt_sql(f)?;
1335 write!(f, " ")?;
1336 }
1337
1338 if let Some(e) = &self.body.else_expr {
1339 write!(f, "ELSE ")?;
1340 e.fmt_sql(f)?;
1341 write!(f, " ")?;
1342 }
1343 write!(f, "END")
1344 }
1345}
1346
1347fn evaluate_predicate(predicate: &Arc<dyn PhysicalExpr>) -> Result<Option<bool>> {
1353 let batch = RecordBatch::try_new_with_options(
1355 Arc::new(Schema::empty()),
1356 vec![],
1357 &RecordBatchOptions::new().with_row_count(Some(1)),
1358 )?;
1359
1360 let result = match predicate.evaluate(&batch) {
1362 Err(_) => None,
1364 Ok(ColumnarValue::Array(array)) => Some(
1365 ScalarValue::try_from_array(array.as_ref(), 0)?
1366 .cast_to(&DataType::Boolean)?,
1367 ),
1368 Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?),
1369 };
1370 Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true)))))
1371}
1372
1373fn replace_with_null(
1374 expr: &Arc<dyn PhysicalExpr>,
1375 expr_to_replace: &dyn PhysicalExpr,
1376 input_schema: &Schema,
1377) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
1378 let with_null = Arc::clone(expr)
1379 .transform_down(|e| {
1380 if e.as_ref().dyn_eq(expr_to_replace) {
1381 let data_type = e.data_type(input_schema)?;
1382 let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
1383 Ok(Transformed::yes(null_literal))
1384 } else {
1385 Ok(Transformed::no(e))
1386 }
1387 })?
1388 .data;
1389 Ok(with_null)
1390}
1391
1392pub fn case(
1394 expr: Option<Arc<dyn PhysicalExpr>>,
1395 when_thens: Vec<WhenThen>,
1396 else_expr: Option<Arc<dyn PhysicalExpr>>,
1397) -> Result<Arc<dyn PhysicalExpr>> {
1398 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1399}
1400
1401#[cfg(test)]
1402mod tests {
1403 use super::*;
1404
1405 use crate::expressions;
1406 use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null, lit};
1407 use arrow::buffer::Buffer;
1408 use arrow::datatypes::DataType::Float64;
1409 use arrow::datatypes::Field;
1410 use datafusion_common::cast::{as_float64_array, as_int32_array};
1411 use datafusion_common::plan_err;
1412 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1413 use datafusion_expr::type_coercion::binary::comparison_coercion;
1414 use datafusion_expr_common::operator::Operator;
1415 use datafusion_physical_expr_common::physical_expr::fmt_sql;
1416 use half::f16;
1417
1418 #[test]
1419 fn case_with_expr() -> Result<()> {
1420 let batch = case_test_batch()?;
1421 let schema = batch.schema();
1422
1423 let when1 = lit("foo");
1425 let then1 = lit(123i32);
1426 let when2 = lit("bar");
1427 let then2 = lit(456i32);
1428
1429 let expr = generate_case_when_with_type_coercion(
1430 Some(col("a", &schema)?),
1431 vec![(when1, then1), (when2, then2)],
1432 None,
1433 schema.as_ref(),
1434 )?;
1435 let result = expr
1436 .evaluate(&batch)?
1437 .into_array(batch.num_rows())
1438 .expect("Failed to convert to array");
1439 let result = as_int32_array(&result)?;
1440
1441 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1442
1443 assert_eq!(expected, result);
1444
1445 Ok(())
1446 }
1447
1448 #[test]
1449 fn case_with_expr_dictionary() -> Result<()> {
1450 let schema = Schema::new(vec![Field::new(
1451 "a",
1452 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1453 true,
1454 )]);
1455 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1456 let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1457 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1458 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1459
1460 let schema = batch.schema();
1461
1462 let when1 = lit("foo");
1464 let then1 = lit(123i32);
1465 let when2 = lit("bar");
1466 let then2 = lit(456i32);
1467
1468 let expr = generate_case_when_with_type_coercion(
1469 Some(col("a", &schema)?),
1470 vec![(when1, then1), (when2, then2)],
1471 None,
1472 schema.as_ref(),
1473 )?;
1474 let result = expr
1475 .evaluate(&batch)?
1476 .into_array(batch.num_rows())
1477 .expect("Failed to convert to array");
1478 let result = as_int32_array(&result)?;
1479
1480 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1481
1482 assert_eq!(expected, result);
1483
1484 Ok(())
1485 }
1486
1487 #[test]
1489 fn case_with_expr_primitive_dictionary() -> Result<()> {
1490 let schema = Schema::new(vec![Field::new(
1491 "a",
1492 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64)),
1493 true,
1494 )]);
1495 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1496 let values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]);
1497 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1498 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1499
1500 let schema = batch.schema();
1501
1502 let when1 = lit(10_u64);
1504 let then1 = lit(123_i32);
1505 let when2 = lit(30_u64);
1506 let then2 = lit(456_i32);
1507
1508 let expr = generate_case_when_with_type_coercion(
1509 Some(col("a", &schema)?),
1510 vec![(when1, then1), (when2, then2)],
1511 None,
1512 schema.as_ref(),
1513 )?;
1514 let result = expr
1515 .evaluate(&batch)?
1516 .into_array(batch.num_rows())
1517 .expect("Failed to convert to array");
1518 let result = as_int32_array(&result)?;
1519
1520 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1521
1522 assert_eq!(expected, result);
1523
1524 Ok(())
1525 }
1526
1527 #[test]
1529 fn case_with_expr_boolean_dictionary() -> Result<()> {
1530 let schema = Schema::new(vec![Field::new(
1531 "a",
1532 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean)),
1533 true,
1534 )]);
1535 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1536 let values = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
1537 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1538 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1539
1540 let schema = batch.schema();
1541
1542 let when1 = lit(true);
1544 let then1 = lit(123i32);
1545 let when2 = lit(false);
1546 let then2 = lit(456i32);
1547
1548 let expr = generate_case_when_with_type_coercion(
1549 Some(col("a", &schema)?),
1550 vec![(when1, then1), (when2, then2)],
1551 None,
1552 schema.as_ref(),
1553 )?;
1554 let result = expr
1555 .evaluate(&batch)?
1556 .into_array(batch.num_rows())
1557 .expect("Failed to convert to array");
1558 let result = as_int32_array(&result)?;
1559
1560 let expected = &Int32Array::from(vec![Some(123), Some(456), None, Some(123)]);
1561
1562 assert_eq!(expected, result);
1563
1564 Ok(())
1565 }
1566
1567 #[test]
1568 fn case_with_expr_all_null_dictionary() -> Result<()> {
1569 let schema = Schema::new(vec![Field::new(
1570 "a",
1571 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1572 true,
1573 )]);
1574 let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1575 let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1576 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1577 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1578
1579 let schema = batch.schema();
1580
1581 let when1 = lit("foo");
1583 let then1 = lit(123i32);
1584 let when2 = lit("bar");
1585 let then2 = lit(456i32);
1586
1587 let expr = generate_case_when_with_type_coercion(
1588 Some(col("a", &schema)?),
1589 vec![(when1, then1), (when2, then2)],
1590 None,
1591 schema.as_ref(),
1592 )?;
1593 let result = expr
1594 .evaluate(&batch)?
1595 .into_array(batch.num_rows())
1596 .expect("Failed to convert to array");
1597 let result = as_int32_array(&result)?;
1598
1599 let expected = &Int32Array::from(vec![None, None, None, None]);
1600
1601 assert_eq!(expected, result);
1602
1603 Ok(())
1604 }
1605
1606 #[test]
1607 fn case_with_expr_else() -> Result<()> {
1608 let batch = case_test_batch()?;
1609 let schema = batch.schema();
1610
1611 let when1 = lit("foo");
1613 let then1 = lit(123i32);
1614 let when2 = lit("bar");
1615 let then2 = lit(456i32);
1616 let else_value = lit(999i32);
1617
1618 let expr = generate_case_when_with_type_coercion(
1619 Some(col("a", &schema)?),
1620 vec![(when1, then1), (when2, then2)],
1621 Some(else_value),
1622 schema.as_ref(),
1623 )?;
1624 let result = expr
1625 .evaluate(&batch)?
1626 .into_array(batch.num_rows())
1627 .expect("Failed to convert to array");
1628 let result = as_int32_array(&result)?;
1629
1630 let expected =
1631 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1632
1633 assert_eq!(expected, result);
1634
1635 Ok(())
1636 }
1637
1638 #[test]
1639 fn case_with_expr_divide_by_zero() -> Result<()> {
1640 let batch = case_test_batch1()?;
1641 let schema = batch.schema();
1642
1643 let when1 = lit(0i32);
1645 let then1 = lit(ScalarValue::Float64(None));
1646 let else_value = binary(
1647 lit(25.0f64),
1648 Operator::Divide,
1649 cast(col("a", &schema)?, &batch.schema(), Float64)?,
1650 &batch.schema(),
1651 )?;
1652
1653 let expr = generate_case_when_with_type_coercion(
1654 Some(col("a", &schema)?),
1655 vec![(when1, then1)],
1656 Some(else_value),
1657 schema.as_ref(),
1658 )?;
1659 let result = expr
1660 .evaluate(&batch)?
1661 .into_array(batch.num_rows())
1662 .expect("Failed to convert to array");
1663 let result =
1664 as_float64_array(&result).expect("failed to downcast to Float64Array");
1665
1666 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1667
1668 assert_eq!(expected, result);
1669
1670 Ok(())
1671 }
1672
1673 #[test]
1674 fn case_without_expr() -> Result<()> {
1675 let batch = case_test_batch()?;
1676 let schema = batch.schema();
1677
1678 let when1 = binary(
1680 col("a", &schema)?,
1681 Operator::Eq,
1682 lit("foo"),
1683 &batch.schema(),
1684 )?;
1685 let then1 = lit(123i32);
1686 let when2 = binary(
1687 col("a", &schema)?,
1688 Operator::Eq,
1689 lit("bar"),
1690 &batch.schema(),
1691 )?;
1692 let then2 = lit(456i32);
1693
1694 let expr = generate_case_when_with_type_coercion(
1695 None,
1696 vec![(when1, then1), (when2, then2)],
1697 None,
1698 schema.as_ref(),
1699 )?;
1700 let result = expr
1701 .evaluate(&batch)?
1702 .into_array(batch.num_rows())
1703 .expect("Failed to convert to array");
1704 let result = as_int32_array(&result)?;
1705
1706 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1707
1708 assert_eq!(expected, result);
1709
1710 Ok(())
1711 }
1712
1713 #[test]
1714 fn case_with_expr_when_null() -> Result<()> {
1715 let batch = case_test_batch()?;
1716 let schema = batch.schema();
1717
1718 let when1 = lit(ScalarValue::Utf8(None));
1720 let then1 = lit(0i32);
1721 let when2 = col("a", &schema)?;
1722 let then2 = lit(123i32);
1723 let else_value = lit(999i32);
1724
1725 let expr = generate_case_when_with_type_coercion(
1726 Some(col("a", &schema)?),
1727 vec![(when1, then1), (when2, then2)],
1728 Some(else_value),
1729 schema.as_ref(),
1730 )?;
1731 let result = expr
1732 .evaluate(&batch)?
1733 .into_array(batch.num_rows())
1734 .expect("Failed to convert to array");
1735 let result = as_int32_array(&result)?;
1736
1737 let expected =
1738 &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
1739
1740 assert_eq!(expected, result);
1741
1742 Ok(())
1743 }
1744
1745 #[test]
1746 fn case_without_expr_divide_by_zero() -> Result<()> {
1747 let batch = case_test_batch1()?;
1748 let schema = batch.schema();
1749
1750 let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1752 let then1 = binary(
1753 lit(25.0f64),
1754 Operator::Divide,
1755 cast(col("a", &schema)?, &batch.schema(), Float64)?,
1756 &batch.schema(),
1757 )?;
1758 let x = lit(ScalarValue::Float64(None));
1759
1760 let expr = generate_case_when_with_type_coercion(
1761 None,
1762 vec![(when1, then1)],
1763 Some(x),
1764 schema.as_ref(),
1765 )?;
1766 let result = expr
1767 .evaluate(&batch)?
1768 .into_array(batch.num_rows())
1769 .expect("Failed to convert to array");
1770 let result =
1771 as_float64_array(&result).expect("failed to downcast to Float64Array");
1772
1773 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1774
1775 assert_eq!(expected, result);
1776
1777 Ok(())
1778 }
1779
1780 fn case_test_batch1() -> Result<RecordBatch> {
1781 let schema = Schema::new(vec![
1782 Field::new("a", DataType::Int32, true),
1783 Field::new("b", DataType::Int32, true),
1784 Field::new("c", DataType::Int32, true),
1785 ]);
1786 let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1787 let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1788 let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1789 let batch = RecordBatch::try_new(
1790 Arc::new(schema),
1791 vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1792 )?;
1793 Ok(batch)
1794 }
1795
1796 #[test]
1797 fn case_without_expr_else() -> Result<()> {
1798 let batch = case_test_batch()?;
1799 let schema = batch.schema();
1800
1801 let when1 = binary(
1803 col("a", &schema)?,
1804 Operator::Eq,
1805 lit("foo"),
1806 &batch.schema(),
1807 )?;
1808 let then1 = lit(123i32);
1809 let when2 = binary(
1810 col("a", &schema)?,
1811 Operator::Eq,
1812 lit("bar"),
1813 &batch.schema(),
1814 )?;
1815 let then2 = lit(456i32);
1816 let else_value = lit(999i32);
1817
1818 let expr = generate_case_when_with_type_coercion(
1819 None,
1820 vec![(when1, then1), (when2, then2)],
1821 Some(else_value),
1822 schema.as_ref(),
1823 )?;
1824 let result = expr
1825 .evaluate(&batch)?
1826 .into_array(batch.num_rows())
1827 .expect("Failed to convert to array");
1828 let result = as_int32_array(&result)?;
1829
1830 let expected =
1831 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1832
1833 assert_eq!(expected, result);
1834
1835 Ok(())
1836 }
1837
1838 #[test]
1839 fn case_with_type_cast() -> Result<()> {
1840 let batch = case_test_batch()?;
1841 let schema = batch.schema();
1842
1843 let when = binary(
1845 col("a", &schema)?,
1846 Operator::Eq,
1847 lit("foo"),
1848 &batch.schema(),
1849 )?;
1850 let then = lit(123.3f64);
1851 let else_value = lit(999i32);
1852
1853 let expr = generate_case_when_with_type_coercion(
1854 None,
1855 vec![(when, then)],
1856 Some(else_value),
1857 schema.as_ref(),
1858 )?;
1859 let result = expr
1860 .evaluate(&batch)?
1861 .into_array(batch.num_rows())
1862 .expect("Failed to convert to array");
1863 let result =
1864 as_float64_array(&result).expect("failed to downcast to Float64Array");
1865
1866 let expected =
1867 &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
1868
1869 assert_eq!(expected, result);
1870
1871 Ok(())
1872 }
1873
1874 #[test]
1875 fn case_with_matches_and_nulls() -> Result<()> {
1876 let batch = case_test_batch_nulls()?;
1877 let schema = batch.schema();
1878
1879 let when = binary(
1881 col("load4", &schema)?,
1882 Operator::Eq,
1883 lit(1.77f64),
1884 &batch.schema(),
1885 )?;
1886 let then = col("load4", &schema)?;
1887
1888 let expr = generate_case_when_with_type_coercion(
1889 None,
1890 vec![(when, then)],
1891 None,
1892 schema.as_ref(),
1893 )?;
1894 let result = expr
1895 .evaluate(&batch)?
1896 .into_array(batch.num_rows())
1897 .expect("Failed to convert to array");
1898 let result =
1899 as_float64_array(&result).expect("failed to downcast to Float64Array");
1900
1901 let expected =
1902 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1903
1904 assert_eq!(expected, result);
1905
1906 Ok(())
1907 }
1908
1909 #[test]
1910 fn case_with_scalar_predicate() -> Result<()> {
1911 let batch = case_test_batch_nulls()?;
1912 let schema = batch.schema();
1913
1914 let when = lit(true);
1916 let then = col("load4", &schema)?;
1917 let expr = generate_case_when_with_type_coercion(
1918 None,
1919 vec![(when, then)],
1920 None,
1921 schema.as_ref(),
1922 )?;
1923
1924 let result = expr
1926 .evaluate(&batch)?
1927 .into_array(batch.num_rows())
1928 .expect("Failed to convert to array");
1929 let result =
1930 as_float64_array(&result).expect("failed to downcast to Float64Array");
1931 let expected = &Float64Array::from(vec![
1932 Some(1.77),
1933 None,
1934 None,
1935 Some(1.78),
1936 None,
1937 Some(1.77),
1938 ]);
1939 assert_eq!(expected, result);
1940
1941 let expected = Float64Array::from(vec![Some(1.1)]);
1943 let batch =
1944 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
1945 let result = expr
1946 .evaluate(&batch)?
1947 .into_array(batch.num_rows())
1948 .expect("Failed to convert to array");
1949 let result =
1950 as_float64_array(&result).expect("failed to downcast to Float64Array");
1951 assert_eq!(&expected, result);
1952
1953 Ok(())
1954 }
1955
1956 #[test]
1957 fn case_expr_matches_and_nulls() -> Result<()> {
1958 let batch = case_test_batch_nulls()?;
1959 let schema = batch.schema();
1960
1961 let expr = col("load4", &schema)?;
1963 let when = lit(1.77f64);
1964 let then = col("load4", &schema)?;
1965
1966 let expr = generate_case_when_with_type_coercion(
1967 Some(expr),
1968 vec![(when, then)],
1969 None,
1970 schema.as_ref(),
1971 )?;
1972 let result = expr
1973 .evaluate(&batch)?
1974 .into_array(batch.num_rows())
1975 .expect("Failed to convert to array");
1976 let result =
1977 as_float64_array(&result).expect("failed to downcast to Float64Array");
1978
1979 let expected =
1980 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1981
1982 assert_eq!(expected, result);
1983
1984 Ok(())
1985 }
1986
1987 #[test]
1988 fn test_when_null_and_some_cond_else_null() -> Result<()> {
1989 let batch = case_test_batch()?;
1990 let schema = batch.schema();
1991
1992 let when = binary(
1993 Arc::new(Literal::new(ScalarValue::Boolean(None))),
1994 Operator::And,
1995 binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
1996 &schema,
1997 )?;
1998 let then = col("a", &schema)?;
1999
2000 let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
2002 let result = expr
2003 .evaluate(&batch)?
2004 .into_array(batch.num_rows())
2005 .expect("Failed to convert to array");
2006 let result = as_string_array(&result);
2007
2008 assert_eq!(result.logical_null_count(), batch.num_rows());
2010 Ok(())
2011 }
2012
2013 fn case_test_batch() -> Result<RecordBatch> {
2014 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2015 let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
2016 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
2017 Ok(batch)
2018 }
2019
2020 fn case_test_batch_nulls() -> Result<RecordBatch> {
2023 let load4: Float64Array = vec![
2024 Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
2031 .into_iter()
2032 .collect();
2033
2034 let null_buffer = Buffer::from([0b00101001u8]);
2035 let load4 = load4
2036 .into_data()
2037 .into_builder()
2038 .null_bit_buffer(Some(null_buffer))
2039 .build()
2040 .unwrap();
2041 let load4: Float64Array = load4.into();
2042
2043 let batch =
2044 RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
2045 Ok(batch)
2046 }
2047
2048 #[test]
2049 fn case_test_incompatible() -> Result<()> {
2050 let batch = case_test_batch()?;
2053 let schema = batch.schema();
2054
2055 let when1 = binary(
2057 col("a", &schema)?,
2058 Operator::Eq,
2059 lit("foo"),
2060 &batch.schema(),
2061 )?;
2062 let then1 = lit(123i32);
2063 let when2 = binary(
2064 col("a", &schema)?,
2065 Operator::Eq,
2066 lit("bar"),
2067 &batch.schema(),
2068 )?;
2069 let then2 = lit(true);
2070
2071 let expr = generate_case_when_with_type_coercion(
2072 None,
2073 vec![(when1, then1), (when2, then2)],
2074 None,
2075 schema.as_ref(),
2076 );
2077 assert!(expr.is_err());
2078
2079 let when1 = binary(
2084 col("a", &schema)?,
2085 Operator::Eq,
2086 lit("foo"),
2087 &batch.schema(),
2088 )?;
2089 let then1 = lit(123i32);
2090 let when2 = binary(
2091 col("a", &schema)?,
2092 Operator::Eq,
2093 lit("bar"),
2094 &batch.schema(),
2095 )?;
2096 let then2 = lit(456i64);
2097 let else_expr = lit(1.23f64);
2098
2099 let expr = generate_case_when_with_type_coercion(
2100 None,
2101 vec![(when1, then1), (when2, then2)],
2102 Some(else_expr),
2103 schema.as_ref(),
2104 );
2105 assert!(expr.is_ok());
2106 let result_type = expr.unwrap().data_type(schema.as_ref())?;
2107 assert_eq!(Float64, result_type);
2108 Ok(())
2109 }
2110
2111 #[test]
2112 fn case_eq() -> Result<()> {
2113 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2114
2115 let when1 = lit("foo");
2116 let then1 = lit(123i32);
2117 let when2 = lit("bar");
2118 let then2 = lit(456i32);
2119 let else_value = lit(999i32);
2120
2121 let expr1 = generate_case_when_with_type_coercion(
2122 Some(col("a", &schema)?),
2123 vec![
2124 (Arc::clone(&when1), Arc::clone(&then1)),
2125 (Arc::clone(&when2), Arc::clone(&then2)),
2126 ],
2127 Some(Arc::clone(&else_value)),
2128 &schema,
2129 )?;
2130
2131 let expr2 = generate_case_when_with_type_coercion(
2132 Some(col("a", &schema)?),
2133 vec![
2134 (Arc::clone(&when1), Arc::clone(&then1)),
2135 (Arc::clone(&when2), Arc::clone(&then2)),
2136 ],
2137 Some(Arc::clone(&else_value)),
2138 &schema,
2139 )?;
2140
2141 let expr3 = generate_case_when_with_type_coercion(
2142 Some(col("a", &schema)?),
2143 vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
2144 None,
2145 &schema,
2146 )?;
2147
2148 let expr4 = generate_case_when_with_type_coercion(
2149 Some(col("a", &schema)?),
2150 vec![(when1, then1)],
2151 Some(else_value),
2152 &schema,
2153 )?;
2154
2155 assert!(expr1.eq(&expr2));
2156 assert!(expr2.eq(&expr1));
2157
2158 assert!(expr2.ne(&expr3));
2159 assert!(expr3.ne(&expr2));
2160
2161 assert!(expr1.ne(&expr4));
2162 assert!(expr4.ne(&expr1));
2163
2164 Ok(())
2165 }
2166
2167 #[test]
2168 fn case_transform() -> Result<()> {
2169 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2170
2171 let when1 = lit("foo");
2172 let then1 = lit(123i32);
2173 let when2 = lit("bar");
2174 let then2 = lit(456i32);
2175 let else_value = lit(999i32);
2176
2177 let expr = generate_case_when_with_type_coercion(
2178 Some(col("a", &schema)?),
2179 vec![
2180 (Arc::clone(&when1), Arc::clone(&then1)),
2181 (Arc::clone(&when2), Arc::clone(&then2)),
2182 ],
2183 Some(Arc::clone(&else_value)),
2184 &schema,
2185 )?;
2186
2187 let expr2 = Arc::clone(&expr)
2188 .transform(|e| {
2189 let transformed = match e.as_any().downcast_ref::<Literal>() {
2190 Some(lit_value) => match lit_value.value() {
2191 ScalarValue::Utf8(Some(str_value)) => {
2192 Some(lit(str_value.to_uppercase()))
2193 }
2194 _ => None,
2195 },
2196 _ => None,
2197 };
2198 Ok(if let Some(transformed) = transformed {
2199 Transformed::yes(transformed)
2200 } else {
2201 Transformed::no(e)
2202 })
2203 })
2204 .data()
2205 .unwrap();
2206
2207 let expr3 = Arc::clone(&expr)
2208 .transform_down(|e| {
2209 let transformed = match e.as_any().downcast_ref::<Literal>() {
2210 Some(lit_value) => match lit_value.value() {
2211 ScalarValue::Utf8(Some(str_value)) => {
2212 Some(lit(str_value.to_uppercase()))
2213 }
2214 _ => None,
2215 },
2216 _ => None,
2217 };
2218 Ok(if let Some(transformed) = transformed {
2219 Transformed::yes(transformed)
2220 } else {
2221 Transformed::no(e)
2222 })
2223 })
2224 .data()
2225 .unwrap();
2226
2227 assert!(expr.ne(&expr2));
2228 assert!(expr2.eq(&expr3));
2229
2230 Ok(())
2231 }
2232
2233 #[test]
2234 fn test_column_or_null_specialization() -> Result<()> {
2235 let mut c1 = Int32Builder::new();
2237 let mut c2 = StringBuilder::new();
2238 for i in 0..1000 {
2239 c1.append_value(i);
2240 if i % 7 == 0 {
2241 c2.append_null();
2242 } else {
2243 c2.append_value(format!("string {i}"));
2244 }
2245 }
2246 let c1 = Arc::new(c1.finish());
2247 let c2 = Arc::new(c2.finish());
2248 let schema = Schema::new(vec![
2249 Field::new("c1", DataType::Int32, true),
2250 Field::new("c2", DataType::Utf8, true),
2251 ]);
2252 let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
2253
2254 let predicate = Arc::new(BinaryExpr::new(
2256 make_col("c1", 0),
2257 Operator::LtEq,
2258 make_lit_i32(250),
2259 ));
2260 let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
2261 assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
2262 match expr.evaluate(&batch)? {
2263 ColumnarValue::Array(array) => {
2264 assert_eq!(1000, array.len());
2265 assert_eq!(785, array.null_count());
2266 }
2267 _ => unreachable!(),
2268 }
2269 Ok(())
2270 }
2271
2272 #[test]
2273 fn test_expr_or_expr_specialization() -> Result<()> {
2274 let batch = case_test_batch1()?;
2275 let schema = batch.schema();
2276 let when = binary(
2277 col("a", &schema)?,
2278 Operator::LtEq,
2279 lit(2i32),
2280 &batch.schema(),
2281 )?;
2282 let then = col("b", &schema)?;
2283 let else_expr = col("c", &schema)?;
2284 let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
2285 assert!(matches!(
2286 expr.eval_method,
2287 EvalMethod::ExpressionOrExpression(_)
2288 ));
2289 let result = expr
2290 .evaluate(&batch)?
2291 .into_array(batch.num_rows())
2292 .expect("Failed to convert to array");
2293 let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
2294
2295 let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
2296
2297 assert_eq!(expected, result);
2298 Ok(())
2299 }
2300
2301 fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2302 Arc::new(Column::new(name, index))
2303 }
2304
2305 fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2306 Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2307 }
2308
2309 fn generate_case_when_with_type_coercion(
2310 expr: Option<Arc<dyn PhysicalExpr>>,
2311 when_thens: Vec<WhenThen>,
2312 else_expr: Option<Arc<dyn PhysicalExpr>>,
2313 input_schema: &Schema,
2314 ) -> Result<Arc<dyn PhysicalExpr>> {
2315 let coerce_type =
2316 get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2317 let (when_thens, else_expr) = match coerce_type {
2318 None => plan_err!(
2319 "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2320 ),
2321 Some(data_type) => {
2322 let left = when_thens
2324 .into_iter()
2325 .map(|(when, then)| {
2326 let then = try_cast(then, input_schema, data_type.clone())?;
2327 Ok((when, then))
2328 })
2329 .collect::<Result<Vec<_>>>()?;
2330 let right = match else_expr {
2331 None => None,
2332 Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2333 };
2334
2335 Ok((left, right))
2336 }
2337 }?;
2338 case(expr, when_thens, else_expr)
2339 }
2340
2341 fn get_case_common_type(
2342 when_thens: &[WhenThen],
2343 else_expr: Option<Arc<dyn PhysicalExpr>>,
2344 input_schema: &Schema,
2345 ) -> Option<DataType> {
2346 let thens_type = when_thens
2347 .iter()
2348 .map(|when_then| {
2349 let data_type = &when_then.1.data_type(input_schema).unwrap();
2350 data_type.clone()
2351 })
2352 .collect::<Vec<_>>();
2353 let else_type = match else_expr {
2354 None => {
2355 thens_type[0].clone()
2357 }
2358 Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2359 };
2360 thens_type
2361 .iter()
2362 .try_fold(else_type, |left_type, right_type| {
2363 comparison_coercion(&left_type, right_type)
2366 })
2367 }
2368
2369 #[test]
2370 fn test_fmt_sql() -> Result<()> {
2371 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2372
2373 let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
2375 let then = lit(123.3f64);
2376 let else_value = lit(999i32);
2377
2378 let expr = generate_case_when_with_type_coercion(
2379 None,
2380 vec![(when, then)],
2381 Some(else_value),
2382 &schema,
2383 )?;
2384
2385 let display_string = expr.to_string();
2386 assert_eq!(
2387 display_string,
2388 "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2389 );
2390
2391 let sql_string = fmt_sql(expr.as_ref()).to_string();
2392 assert_eq!(
2393 sql_string,
2394 "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2395 );
2396
2397 Ok(())
2398 }
2399
2400 fn when_then_else(
2401 when: &Arc<dyn PhysicalExpr>,
2402 then: &Arc<dyn PhysicalExpr>,
2403 els: &Arc<dyn PhysicalExpr>,
2404 ) -> Result<Arc<dyn PhysicalExpr>> {
2405 let case = CaseExpr::try_new(
2406 None,
2407 vec![(Arc::clone(when), Arc::clone(then))],
2408 Some(Arc::clone(els)),
2409 )?;
2410 Ok(Arc::new(case))
2411 }
2412
2413 #[test]
2414 fn test_case_expression_nullability_with_nullable_column() -> Result<()> {
2415 case_expression_nullability(true)
2416 }
2417
2418 #[test]
2419 fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> {
2420 case_expression_nullability(false)
2421 }
2422
2423 fn case_expression_nullability(col_is_nullable: bool) -> Result<()> {
2424 let schema =
2425 Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]);
2426
2427 let foo = col("foo", &schema)?;
2428 let foo_is_not_null = is_not_null(Arc::clone(&foo))?;
2429 let foo_is_null = expressions::is_null(Arc::clone(&foo))?;
2430 let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?;
2431 let zero = lit(0);
2432 let foo_eq_zero =
2433 binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?;
2434
2435 assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema);
2436 assert_not_nullable(when_then_else(¬_foo_is_null, &foo, &zero)?, &schema);
2437 assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema);
2438
2439 assert_not_nullable(
2440 when_then_else(
2441 &binary(
2442 Arc::clone(&foo_is_not_null),
2443 Operator::And,
2444 Arc::clone(&foo_eq_zero),
2445 &schema,
2446 )?,
2447 &foo,
2448 &zero,
2449 )?,
2450 &schema,
2451 );
2452
2453 assert_not_nullable(
2454 when_then_else(
2455 &binary(
2456 Arc::clone(&foo_eq_zero),
2457 Operator::And,
2458 Arc::clone(&foo_is_not_null),
2459 &schema,
2460 )?,
2461 &foo,
2462 &zero,
2463 )?,
2464 &schema,
2465 );
2466
2467 assert_not_nullable(
2468 when_then_else(
2469 &binary(
2470 Arc::clone(&foo_is_not_null),
2471 Operator::Or,
2472 Arc::clone(&foo_eq_zero),
2473 &schema,
2474 )?,
2475 &foo,
2476 &zero,
2477 )?,
2478 &schema,
2479 );
2480
2481 assert_not_nullable(
2482 when_then_else(
2483 &binary(
2484 Arc::clone(&foo_eq_zero),
2485 Operator::Or,
2486 Arc::clone(&foo_is_not_null),
2487 &schema,
2488 )?,
2489 &foo,
2490 &zero,
2491 )?,
2492 &schema,
2493 );
2494
2495 assert_nullability(
2496 when_then_else(
2497 &binary(
2498 Arc::clone(&foo_is_null),
2499 Operator::Or,
2500 Arc::clone(&foo_eq_zero),
2501 &schema,
2502 )?,
2503 &foo,
2504 &zero,
2505 )?,
2506 &schema,
2507 col_is_nullable,
2508 );
2509
2510 assert_nullability(
2511 when_then_else(
2512 &binary(
2513 binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?,
2514 Operator::Or,
2515 Arc::clone(&foo_is_null),
2516 &schema,
2517 )?,
2518 &foo,
2519 &zero,
2520 )?,
2521 &schema,
2522 col_is_nullable,
2523 );
2524
2525 assert_not_nullable(
2526 when_then_else(
2527 &binary(
2528 binary(
2529 binary(
2530 Arc::clone(&foo),
2531 Operator::Eq,
2532 Arc::clone(&zero),
2533 &schema,
2534 )?,
2535 Operator::And,
2536 Arc::clone(&foo_is_not_null),
2537 &schema,
2538 )?,
2539 Operator::Or,
2540 binary(
2541 binary(
2542 Arc::clone(&foo),
2543 Operator::Eq,
2544 Arc::clone(&foo),
2545 &schema,
2546 )?,
2547 Operator::And,
2548 Arc::clone(&foo_is_not_null),
2549 &schema,
2550 )?,
2551 &schema,
2552 )?,
2553 &foo,
2554 &zero,
2555 )?,
2556 &schema,
2557 );
2558
2559 Ok(())
2560 }
2561
    /// Panics unless `expr.nullable(schema)` returns `Ok(false)`.
    /// A nullability error also panics, via `unwrap`.
    fn assert_not_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
        assert!(!expr.nullable(schema).unwrap());
    }
2565
    /// Panics unless `expr.nullable(schema)` returns `Ok(true)`.
    /// A nullability error also panics, via `unwrap`.
    fn assert_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
        assert!(expr.nullable(schema).unwrap());
    }
2569
2570 fn assert_nullability(expr: Arc<dyn PhysicalExpr>, schema: &Schema, nullable: bool) {
2571 if nullable {
2572 assert_nullable(expr, schema);
2573 } else {
2574 assert_not_nullable(expr, schema);
2575 }
2576 }
2577
2578 fn test_case_when_literal_lookup(
2581 values: ArrayRef,
2582 lookup_map: &[(ScalarValue, ScalarValue)],
2583 else_value: Option<ScalarValue>,
2584 expected: ArrayRef,
2585 ) {
2586 let schema = Schema::new(vec![Field::new(
2593 "a",
2594 values.data_type().clone(),
2595 values.is_nullable(),
2596 )]);
2597 let schema = Arc::new(schema);
2598
2599 let batch = RecordBatch::try_new(schema, vec![values])
2600 .expect("failed to create RecordBatch");
2601
2602 let schema = batch.schema_ref();
2603 let case = col("a", schema).expect("failed to create col");
2604
2605 let when_then = lookup_map
2606 .iter()
2607 .map(|(when, then)| {
2608 (
2609 Arc::new(Literal::new(when.clone())) as _,
2610 Arc::new(Literal::new(then.clone())) as _,
2611 )
2612 })
2613 .collect::<Vec<WhenThen>>();
2614
2615 let else_expr = else_value.map(|else_value| {
2616 Arc::new(Literal::new(else_value)) as Arc<dyn PhysicalExpr>
2617 });
2618 let expr = CaseExpr::try_new(Some(case), when_then, else_expr)
2619 .expect("failed to create case");
2620
2621 assert!(
2623 matches!(
2624 expr.eval_method,
2625 EvalMethod::WithExprScalarLookupTable { .. }
2626 ),
2627 "we should use the expected eval method"
2628 );
2629
2630 let actual = expr
2631 .evaluate(&batch)
2632 .expect("failed to evaluate case")
2633 .into_array(batch.num_rows())
2634 .expect("Failed to convert to array");
2635
2636 assert_eq!(
2637 actual.data_type(),
2638 expected.data_type(),
2639 "Data type mismatch"
2640 );
2641
2642 assert_eq!(
2643 actual.as_ref(),
2644 expected.as_ref(),
2645 "actual (left) does not match expected (right)"
2646 );
2647 }
2648
2649 fn create_lookup<When, Then>(
2650 when_then_pairs: impl IntoIterator<Item = (When, Then)>,
2651 ) -> Vec<(ScalarValue, ScalarValue)>
2652 where
2653 ScalarValue: From<When>,
2654 ScalarValue: From<Then>,
2655 {
2656 when_then_pairs
2657 .into_iter()
2658 .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then)))
2659 .collect()
2660 }
2661
2662 fn create_input_and_expected<Input, Expected, InputFromItem, ExpectedFromItem>(
2663 input_and_expected_pairs: impl IntoIterator<Item = (InputFromItem, ExpectedFromItem)>,
2664 ) -> (Input, Expected)
2665 where
2666 Input: Array + From<Vec<InputFromItem>>,
2667 Expected: Array + From<Vec<ExpectedFromItem>>,
2668 {
2669 let (input_items, expected_items): (Vec<InputFromItem>, Vec<ExpectedFromItem>) =
2670 input_and_expected_pairs.into_iter().unzip();
2671
2672 (Input::from(input_items), Expected::from(expected_items))
2673 }
2674
2675 fn test_lookup_eval_with_and_without_else(
2676 lookup_map: &[(ScalarValue, ScalarValue)],
2677 input_values: ArrayRef,
2678 expected: StringArray,
2679 ) {
2680 test_case_when_literal_lookup(
2682 Arc::clone(&input_values),
2683 lookup_map,
2684 None,
2685 Arc::new(expected.clone()),
2686 );
2687
2688 let else_value = "___fallback___";
2690
2691 let expected_with_else = expected
2693 .iter()
2694 .map(|item| item.unwrap_or(else_value))
2695 .map(Some)
2696 .collect::<StringArray>();
2697
2698 test_case_when_literal_lookup(
2700 input_values,
2701 lookup_map,
2702 Some(ScalarValue::Utf8(Some(else_value.to_string()))),
2703 Arc::new(expected_with_else),
2704 );
2705 }
2706
2707 #[test]
2708 fn test_case_when_literal_lookup_int32_to_string() {
2709 let lookup_map = create_lookup([
2710 (Some(4), Some("four")),
2711 (Some(2), Some("two")),
2712 (Some(3), Some("three")),
2713 (Some(1), Some("one")),
2714 ]);
2715
2716 let (input_values, expected) =
2717 create_input_and_expected::<Int32Array, StringArray, _, _>([
2718 (1, Some("one")),
2719 (2, Some("two")),
2720 (3, Some("three")),
2721 (3, Some("three")),
2722 (2, Some("two")),
2723 (3, Some("three")),
2724 (5, None), (5, None), (3, Some("three")),
2727 (5, None), ]);
2729
2730 test_lookup_eval_with_and_without_else(
2731 &lookup_map,
2732 Arc::new(input_values),
2733 expected,
2734 );
2735 }
2736
2737 #[test]
2738 fn test_case_when_literal_lookup_none_case_should_never_match() {
2739 let lookup_map = create_lookup([
2740 (Some(4), Some("four")),
2741 (None, Some("none")),
2742 (Some(2), Some("two")),
2743 (Some(1), Some("one")),
2744 ]);
2745
2746 let (input_values, expected) =
2747 create_input_and_expected::<Int32Array, StringArray, _, _>([
2748 (Some(1), Some("one")),
2749 (Some(5), None), (None, None), (Some(2), Some("two")),
2752 (None, None), (None, None), (Some(2), Some("two")),
2755 (Some(5), None), ]);
2757
2758 test_lookup_eval_with_and_without_else(
2759 &lookup_map,
2760 Arc::new(input_values),
2761 expected,
2762 );
2763 }
2764
2765 #[test]
2766 fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() {
2767 let lookup_map = create_lookup([
2768 (Some(4), Some("four")),
2769 (Some(4), Some("no 4")),
2770 (Some(2), Some("two")),
2771 (Some(2), Some("no 2")),
2772 (Some(3), Some("three")),
2773 (Some(3), Some("no 3")),
2774 (Some(2), Some("no 2")),
2775 (Some(4), Some("no 4")),
2776 (Some(2), Some("no 2")),
2777 (Some(3), Some("no 3")),
2778 (Some(4), Some("no 4")),
2779 (Some(2), Some("no 2")),
2780 (Some(3), Some("no 3")),
2781 (Some(3), Some("no 3")),
2782 ]);
2783
2784 let (input_values, expected) =
2785 create_input_and_expected::<Int32Array, StringArray, _, _>([
2786 (1, None), (2, Some("two")),
2788 (3, Some("three")),
2789 (3, Some("three")),
2790 (2, Some("two")),
2791 (3, Some("three")),
2792 (5, None), (5, None), (3, Some("three")),
2795 (5, None), ]);
2797
2798 test_lookup_eval_with_and_without_else(
2799 &lookup_map,
2800 Arc::new(input_values),
2801 expected,
2802 );
2803 }
2804
2805 #[test]
2806 fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases()
2807 {
2808 let lookup_map = create_lookup([
2809 (Some(4.0), Some("four point zero")),
2810 (Some(f32::NAN), Some("NaN")),
2811 (Some(3.2), Some("three point two")),
2812 (Some(f32::NAN), Some("should not use this NaN branch")),
2814 (Some(f32::INFINITY), Some("Infinity")),
2815 (Some(0.0), Some("zero")),
2816 (
2818 Some(f32::INFINITY),
2819 Some("should not use this Infinity branch"),
2820 ),
2821 (Some(1.1), Some("one point one")),
2822 ]);
2823
2824 let (input_values, expected) =
2825 create_input_and_expected::<Float32Array, StringArray, _, _>([
2826 (1.1, Some("one point one")),
2827 (f32::NAN, Some("NaN")),
2828 (3.2, Some("three point two")),
2829 (3.2, Some("three point two")),
2830 (0.0, Some("zero")),
2831 (f32::INFINITY, Some("Infinity")),
2832 (3.2, Some("three point two")),
2833 (f32::NEG_INFINITY, None), (f32::NEG_INFINITY, None), (3.2, Some("three point two")),
2836 (-0.0, None), ]);
2838
2839 test_lookup_eval_with_and_without_else(
2840 &lookup_map,
2841 Arc::new(input_values),
2842 expected,
2843 );
2844 }
2845
2846 #[test]
2847 fn test_case_when_literal_lookup_f16_to_string_with_special_values() {
2848 let lookup_map = create_lookup([
2849 (
2850 ScalarValue::Float16(Some(f16::from_f32(3.2))),
2851 Some("3 dot 2"),
2852 ),
2853 (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")),
2854 (
2855 ScalarValue::Float16(Some(f16::from_f32(17.4))),
2856 Some("17 dot 4"),
2857 ),
2858 (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")),
2859 (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")),
2860 ]);
2861
2862 let (input_values, expected) =
2863 create_input_and_expected::<Float16Array, StringArray, _, _>([
2864 (f16::from_f32(3.2), Some("3 dot 2")),
2865 (f16::NAN, Some("NaN")),
2866 (f16::from_f32(17.4), Some("17 dot 4")),
2867 (f16::from_f32(17.4), Some("17 dot 4")),
2868 (f16::INFINITY, Some("Infinity")),
2869 (f16::from_f32(17.4), Some("17 dot 4")),
2870 (f16::NEG_INFINITY, None), (f16::NEG_INFINITY, None), (f16::from_f32(17.4), Some("17 dot 4")),
2873 (f16::NEG_ZERO, None), ]);
2875
2876 test_lookup_eval_with_and_without_else(
2877 &lookup_map,
2878 Arc::new(input_values),
2879 expected,
2880 );
2881 }
2882
2883 #[test]
2884 fn test_case_when_literal_lookup_f32_to_string_with_special_values() {
2885 let lookup_map = create_lookup([
2886 (3.2, Some("3 dot 2")),
2887 (f32::NAN, Some("NaN")),
2888 (17.4, Some("17 dot 4")),
2889 (f32::INFINITY, Some("Infinity")),
2890 (f32::ZERO, Some("zero")),
2891 ]);
2892
2893 let (input_values, expected) =
2894 create_input_and_expected::<Float32Array, StringArray, _, _>([
2895 (3.2, Some("3 dot 2")),
2896 (f32::NAN, Some("NaN")),
2897 (17.4, Some("17 dot 4")),
2898 (17.4, Some("17 dot 4")),
2899 (f32::INFINITY, Some("Infinity")),
2900 (17.4, Some("17 dot 4")),
2901 (f32::NEG_INFINITY, None), (f32::NEG_INFINITY, None), (17.4, Some("17 dot 4")),
2904 (-0.0, None), ]);
2906
2907 test_lookup_eval_with_and_without_else(
2908 &lookup_map,
2909 Arc::new(input_values),
2910 expected,
2911 );
2912 }
2913
2914 #[test]
2915 fn test_case_when_literal_lookup_f64_to_string_with_special_values() {
2916 let lookup_map = create_lookup([
2917 (3.2, Some("3 dot 2")),
2918 (f64::NAN, Some("NaN")),
2919 (17.4, Some("17 dot 4")),
2920 (f64::INFINITY, Some("Infinity")),
2921 (f64::ZERO, Some("zero")),
2922 ]);
2923
2924 let (input_values, expected) =
2925 create_input_and_expected::<Float64Array, StringArray, _, _>([
2926 (3.2, Some("3 dot 2")),
2927 (f64::NAN, Some("NaN")),
2928 (17.4, Some("17 dot 4")),
2929 (17.4, Some("17 dot 4")),
2930 (f64::INFINITY, Some("Infinity")),
2931 (17.4, Some("17 dot 4")),
2932 (f64::NEG_INFINITY, None), (f64::NEG_INFINITY, None), (17.4, Some("17 dot 4")),
2935 (-0.0, None), ]);
2937
2938 test_lookup_eval_with_and_without_else(
2939 &lookup_map,
2940 Arc::new(input_values),
2941 expected,
2942 );
2943 }
2944
2945 #[test]
2947 fn test_decimal_with_non_default_precision_and_scale() {
2948 let lookup_map = create_lookup([
2949 (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")),
2950 (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")),
2951 (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")),
2952 (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")),
2953 ]);
2954
2955 let (input_values, expected) =
2956 create_input_and_expected::<Decimal32Array, StringArray, _, _>([
2957 (1, Some("one")),
2958 (2, Some("two")),
2959 (3, Some("three")),
2960 (3, Some("three")),
2961 (2, Some("two")),
2962 (3, Some("three")),
2963 (5, None), (5, None), (3, Some("three")),
2966 (5, None), ]);
2968
2969 let input_values = input_values
2970 .with_precision_and_scale(3, 2)
2971 .expect("must be able to set precision and scale");
2972
2973 test_lookup_eval_with_and_without_else(
2974 &lookup_map,
2975 Arc::new(input_values),
2976 expected,
2977 );
2978 }
2979
2980 #[test]
2982 fn test_timestamp_with_non_default_timezone() {
2983 let timezone: Option<Arc<str>> = Some("-10:00".into());
2984 let lookup_map = create_lookup([
2985 (
2986 ScalarValue::TimestampMillisecond(Some(4), timezone.clone()),
2987 Some("four"),
2988 ),
2989 (
2990 ScalarValue::TimestampMillisecond(Some(2), timezone.clone()),
2991 Some("two"),
2992 ),
2993 (
2994 ScalarValue::TimestampMillisecond(Some(3), timezone.clone()),
2995 Some("three"),
2996 ),
2997 (
2998 ScalarValue::TimestampMillisecond(Some(1), timezone.clone()),
2999 Some("one"),
3000 ),
3001 ]);
3002
3003 let (input_values, expected) =
3004 create_input_and_expected::<TimestampMillisecondArray, StringArray, _, _>([
3005 (1, Some("one")),
3006 (2, Some("two")),
3007 (3, Some("three")),
3008 (3, Some("three")),
3009 (2, Some("two")),
3010 (3, Some("three")),
3011 (5, None), (5, None), (3, Some("three")),
3014 (5, None), ]);
3016
3017 let input_values = input_values.with_timezone_opt(timezone);
3018
3019 test_lookup_eval_with_and_without_else(
3020 &lookup_map,
3021 Arc::new(input_values),
3022 expected,
3023 );
3024 }
3025
3026 #[test]
3027 fn test_with_strings_to_int32() {
3028 let lookup_map = create_lookup([
3029 (Some("why"), Some(42)),
3030 (Some("what"), Some(22)),
3031 (Some("when"), Some(17)),
3032 ]);
3033
3034 let (input_values, expected) =
3035 create_input_and_expected::<StringArray, Int32Array, _, _>([
3036 (Some("why"), Some(42)),
3037 (Some("5"), None), (None, None), (Some("what"), Some(22)),
3040 (None, None), (None, None), (Some("what"), Some(22)),
3043 (Some("5"), None), ]);
3045
3046 let input_values = Arc::new(input_values) as ArrayRef;
3047
3048 test_case_when_literal_lookup(
3050 Arc::clone(&input_values),
3051 &lookup_map,
3052 None,
3053 Arc::new(expected.clone()),
3054 );
3055
3056 let else_value = 101;
3058
3059 let expected_with_else = expected
3061 .iter()
3062 .map(|item| item.unwrap_or(else_value))
3063 .map(Some)
3064 .collect::<Int32Array>();
3065
3066 test_case_when_literal_lookup(
3068 input_values,
3069 &lookup_map,
3070 Some(ScalarValue::Int32(Some(else_value))),
3071 Arc::new(expected_with_else),
3072 );
3073 }
3074}