1mod literal_lookup_table;
19
20use super::{Column, Literal};
21use crate::PhysicalExpr;
22use crate::expressions::{lit, try_cast};
23use arrow::array::*;
24use arrow::compute::kernels::zip::zip;
25use arrow::compute::{
26 FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::{
32 DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err,
33 internal_datafusion_err, internal_err,
34};
35use datafusion_expr::ColumnarValue;
36use indexmap::{IndexMap, IndexSet};
37use std::borrow::Cow;
38use std::hash::Hash;
39use std::{any::Any, sync::Arc};
40
41use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
42use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
43use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
44use datafusion_physical_expr_common::datum::compare_with_eq;
45use datafusion_physical_expr_common::utils::scatter;
46use itertools::Itertools;
47use std::fmt::{Debug, Formatter};
48
/// A `(WHEN, THEN)` pair of a `CASE` expression: the first element is the
/// `WHEN` expression, the second the value produced when it matches.
pub(super) type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
50
/// Evaluation strategy for a `CaseExpr`, chosen once when the expression is
/// built (see `CaseExpr::find_best_eval_method`).
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// No base expression and more than one `WHEN` branch:
    /// `CASE WHEN c1 THEN t1 WHEN c2 THEN t2 ... [ELSE e] END`.
    NoExpression(ProjectedCaseBody),
    /// A base expression for which no literal lookup table could be built:
    /// `CASE base WHEN v1 THEN t1 ... [ELSE e] END`.
    WithExpression(ProjectedCaseBody),
    /// `CASE WHEN c THEN col END`: a single branch, no `ELSE`, and a `THEN`
    /// that `is_cheap_and_infallible` accepts (a plain column reference).
    InfallibleExprOrNull,
    /// `CASE WHEN c THEN lit1 ELSE lit2 END`: a single branch whose `THEN`
    /// and `ELSE` are both literals.
    ScalarOrScalar,
    /// Any other single-branch form without a base expression:
    /// `CASE WHEN c THEN t [ELSE e] END`.
    ExpressionOrExpression(ProjectedCaseBody),

    /// A base expression for which `LiteralLookupTable::maybe_new` returned a
    /// table; evaluation maps the base values through that table.
    WithExprScalarLookupTable(LiteralLookupTable),
}
90
// `LiteralLookupTable` lives inside `EvalMethod`, which derives `Hash`,
// `PartialEq` and `Eq`. `CaseExpr` derives those traits over both its `body`
// and its `eval_method`, and the table is built from that same body in
// `CaseExpr::find_best_eval_method`, so these impls deliberately ignore the
// table's contents and let the body decide equality/hashing.
// NOTE(review): this relies on `LiteralLookupTable::maybe_new` producing an
// equivalent table for equal bodies — confirm.
impl Hash for LiteralLookupTable {
    /// Intentionally contributes nothing to the hash.
    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
}

impl PartialEq for LiteralLookupTable {
    /// Always `true`: equality is decided by the owning `CaseExpr`'s body.
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

impl Eq for LiteralLookupTable {}
114
/// The sub-expressions that make up a `CASE` expression.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression compared against each `WHEN` value
    /// (`CASE expr WHEN ...`); `None` for the `CASE WHEN cond ...` form.
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// `(WHEN, THEN)` pairs in evaluation order; `CaseExpr::try_new` rejects
    /// an empty list.
    when_then_expr: Vec<WhenThen>,
    /// Optional `ELSE` expression; `CaseExpr::try_new` turns a literal
    /// `ELSE NULL` into `None`.
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
126
127impl CaseBody {
128 fn project(&self) -> Result<ProjectedCaseBody> {
130 let mut used_column_indices = IndexSet::<usize>::new();
132 let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
133 expr.apply(|expr| {
134 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
135 used_column_indices.insert(column.index());
136 }
137 Ok(TreeNodeRecursion::Continue)
138 })
139 .expect("Closure cannot fail");
140 };
141
142 if let Some(e) = &self.expr {
143 collect_column_indices(e);
144 }
145 self.when_then_expr.iter().for_each(|(w, t)| {
146 collect_column_indices(w);
147 collect_column_indices(t);
148 });
149 if let Some(e) = &self.else_expr {
150 collect_column_indices(e);
151 }
152
153 let column_index_map = used_column_indices
155 .iter()
156 .enumerate()
157 .map(|(projected, original)| (*original, projected))
158 .collect::<IndexMap<usize, usize>>();
159
160 let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
163 Arc::clone(expr)
164 .transform_down(|e| {
165 if let Some(column) = e.as_any().downcast_ref::<Column>() {
166 let original = column.index();
167 let projected = *column_index_map.get(&original).unwrap();
168 if projected != original {
169 return Ok(Transformed::yes(Arc::new(Column::new(
170 column.name(),
171 projected,
172 ))));
173 }
174 }
175 Ok(Transformed::no(e))
176 })
177 .map(|t| t.data)
178 };
179
180 let projected_body = CaseBody {
181 expr: self.expr.as_ref().map(project).transpose()?,
182 when_then_expr: self
183 .when_then_expr
184 .iter()
185 .map(|(e, t)| Ok((project(e)?, project(t)?)))
186 .collect::<Result<Vec<_>>>()?,
187 else_expr: self.else_expr.as_ref().map(project).transpose()?,
188 };
189
190 let projection = column_index_map
192 .iter()
193 .sorted_by_key(|(_, v)| **v)
194 .map(|(k, _)| *k)
195 .collect::<Vec<_>>();
196
197 Ok(ProjectedCaseBody {
198 projection,
199 body: projected_body,
200 })
201 }
202}
203
/// A [`CaseBody`] rewritten to reference only the columns it uses
/// (produced by `CaseBody::project`).
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// `projection[i]` is the original input column index of column `i` in
    /// the narrowed batch that `body` expects (suitable for
    /// `RecordBatch::project`).
    projection: Vec<usize>,
    /// Body whose `Column` indices refer to the projected batch.
    body: CaseBody,
}
236
/// Physical `CASE` expression, in either of its two forms:
///
/// ```text
/// CASE WHEN cond1 THEN val1 [WHEN ...] [ELSE else_val] END
/// CASE expr WHEN v1 THEN val1 [WHEN ...] [ELSE else_val] END
/// ```
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// The sub-expressions as given to `try_new` (with `ELSE NULL` dropped).
    body: CaseBody,
    /// Strategy used by `evaluate`, selected from `body` at construction time.
    eval_method: EvalMethod,
}
261
262impl std::fmt::Display for CaseExpr {
263 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
264 write!(f, "CASE ")?;
265 if let Some(e) = &self.body.expr {
266 write!(f, "{e} ")?;
267 }
268 for (w, t) in &self.body.when_then_expr {
269 write!(f, "WHEN {w} THEN {t} ")?;
270 }
271 if let Some(e) = &self.body.else_expr {
272 write!(f, "ELSE {e} ")?;
273 }
274 write!(f, "END")
275 }
276}
277
278fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
284 expr.as_any().is::<Column>()
285}
286
287fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
289 let mut filter_builder = FilterBuilder::new(predicate);
290 if optimize {
291 filter_builder = filter_builder.optimize();
293 }
294 filter_builder.build()
295}
296
297fn multiple_arrays(data_type: &DataType) -> bool {
298 match data_type {
299 DataType::Struct(fields) => {
300 fields.len() > 1
301 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
302 }
303 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
304 _ => false,
305 }
306}
307
308fn filter_record_batch(
311 record_batch: &RecordBatch,
312 filter: &FilterPredicate,
313) -> std::result::Result<RecordBatch, ArrowError> {
314 let filtered_columns = record_batch
315 .columns()
316 .iter()
317 .map(|a| filter_array(a, filter))
318 .collect::<std::result::Result<Vec<_>, _>>()?;
319 unsafe {
325 Ok(RecordBatch::new_unchecked(
326 record_batch.schema(),
327 filtered_columns,
328 filter.count(),
329 ))
330 }
331}
332
333#[inline(always)]
338fn filter_array(
339 array: &dyn Array,
340 filter: &FilterPredicate,
341) -> std::result::Result<ArrayRef, ArrowError> {
342 filter.filter(array)
343}
344
/// Per-row index into the partial result arrays collected by `ResultBuilder`.
///
/// `NONE_VALUE` (`u32::MAX`) is reserved to mean "no result for this row".
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    index: u32,
}
353
/// Sentinel `PartialResultIndex` value meaning "this row has no result".
const NONE_VALUE: u32 = u32::MAX;
355
356impl PartialResultIndex {
357 fn none() -> Self {
359 Self { index: NONE_VALUE }
360 }
361
362 fn zero() -> Self {
363 Self { index: 0 }
364 }
365
366 fn try_new(index: usize) -> Result<Self> {
371 let Ok(index) = u32::try_from(index) else {
372 return internal_err!("Partial result index exceeds limit");
373 };
374
375 assert_or_internal_err!(
376 index != NONE_VALUE,
377 "Partial result index exceeds limit"
378 );
379
380 Ok(Self { index })
381 }
382
383 fn is_none(&self) -> bool {
385 self.index == NONE_VALUE
386 }
387}
388
389impl MergeIndex for PartialResultIndex {
390 fn index(&self) -> Option<usize> {
392 if self.is_none() {
393 None
394 } else {
395 Some(self.index as usize)
396 }
397 }
398}
399
400impl Debug for PartialResultIndex {
401 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
402 if self.is_none() {
403 write!(f, "null")
404 } else {
405 write!(f, "{}", self.index)
406 }
407 }
408}
409
/// Accumulated state of a `ResultBuilder`.
enum ResultState {
    /// No branch has produced values yet.
    Empty,
    /// One or more branches produced values for subsets of the rows.
    Partial {
        /// Values produced by each branch, in the order they were added.
        arrays: Vec<ArrayRef>,
        /// For every row of the batch, which entry of `arrays` supplies its
        /// value, or the "none" index if no branch has covered the row.
        indices: Vec<PartialResultIndex>,
    },
    /// A single branch produced the value for every row.
    Complete(ColumnarValue),
}
425
/// Assembles the output of a `CASE` evaluation from per-branch results.
///
/// Each branch reports the original row positions it covers together with
/// the values for those rows; `finish` combines them into one value. If no
/// branch reported anything, the result is a NULL scalar of `data_type`;
/// rows left uncovered in a partial result get a `None` merge index
/// (presumably NULL in the merged output).
struct ResultBuilder {
    /// Type of the final result; used to build the NULL scalar when empty.
    data_type: DataType,
    /// Number of rows in the batch being evaluated.
    row_count: usize,
    /// Results collected so far.
    state: ResultState,
}
441
442impl ResultBuilder {
443 fn new(data_type: &DataType, row_count: usize) -> Self {
447 Self {
448 data_type: data_type.clone(),
449 row_count,
450 state: ResultState::Empty,
451 }
452 }
453
454 fn add_branch_result(
487 &mut self,
488 row_indices: &ArrayRef,
489 value: ColumnarValue,
490 ) -> Result<()> {
491 match value {
492 ColumnarValue::Array(a) => {
493 if a.len() != row_indices.len() {
494 internal_err!("Array length must match row indices length")
495 } else if row_indices.len() == self.row_count {
496 self.set_complete_result(ColumnarValue::Array(a))
497 } else {
498 self.add_partial_result(row_indices, a)
499 }
500 }
501 ColumnarValue::Scalar(s) => {
502 if row_indices.len() == self.row_count {
503 self.set_complete_result(ColumnarValue::Scalar(s))
504 } else {
505 self.add_partial_result(
506 row_indices,
507 s.to_array_of_size(row_indices.len())?,
508 )
509 }
510 }
511 }
512 }
513
514 fn add_partial_result(
520 &mut self,
521 row_indices: &ArrayRef,
522 row_values: ArrayRef,
523 ) -> Result<()> {
524 assert_or_internal_err!(
525 row_indices.null_count() == 0,
526 "Row indices must not contain nulls"
527 );
528
529 match &mut self.state {
530 ResultState::Empty => {
531 let array_index = PartialResultIndex::zero();
532 let mut indices = vec![PartialResultIndex::none(); self.row_count];
533 for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
534 indices[*row_ix as usize] = array_index;
535 }
536
537 self.state = ResultState::Partial {
538 arrays: vec![row_values],
539 indices,
540 };
541
542 Ok(())
543 }
544 ResultState::Partial { arrays, indices } => {
545 let array_index = PartialResultIndex::try_new(arrays.len())?;
546
547 arrays.push(row_values);
548
549 for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
550 #[cfg(debug_assertions)]
554 assert_or_internal_err!(
555 indices[*row_ix as usize].is_none(),
556 "Duplicate value for row {}",
557 *row_ix
558 );
559
560 indices[*row_ix as usize] = array_index;
561 }
562 Ok(())
563 }
564 ResultState::Complete(_) => internal_err!(
565 "Cannot add a partial result when complete result is already set"
566 ),
567 }
568 }
569
570 fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
576 match &self.state {
577 ResultState::Empty => {
578 self.state = ResultState::Complete(value);
579 Ok(())
580 }
581 ResultState::Partial { .. } => {
582 internal_err!(
583 "Cannot set a complete result when there are already partial results"
584 )
585 }
586 ResultState::Complete(_) => internal_err!("Complete result already set"),
587 }
588 }
589
590 fn finish(self) -> Result<ColumnarValue> {
592 match self.state {
593 ResultState::Empty => {
594 Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
598 &self.data_type,
599 )?))
600 }
601 ResultState::Partial { arrays, indices } => {
602 let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
604 Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
605 }
606 ResultState::Complete(v) => {
607 Ok(v)
609 }
610 }
611 }
612}
613
614impl CaseExpr {
615 pub fn try_new(
617 expr: Option<Arc<dyn PhysicalExpr>>,
618 when_then_expr: Vec<WhenThen>,
619 else_expr: Option<Arc<dyn PhysicalExpr>>,
620 ) -> Result<Self> {
621 let else_expr = match &else_expr {
624 Some(e) => match e.as_any().downcast_ref::<Literal>() {
625 Some(lit) if lit.value().is_null() => None,
626 _ => else_expr,
627 },
628 _ => else_expr,
629 };
630
631 if when_then_expr.is_empty() {
632 return exec_err!("There must be at least one WHEN clause");
633 }
634
635 let body = CaseBody {
636 expr,
637 when_then_expr,
638 else_expr,
639 };
640
641 let eval_method = Self::find_best_eval_method(&body)?;
642
643 Ok(Self { body, eval_method })
644 }
645
646 fn find_best_eval_method(body: &CaseBody) -> Result<EvalMethod> {
647 if body.expr.is_some() {
648 if let Some(mapping) = LiteralLookupTable::maybe_new(body) {
649 return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
650 }
651
652 return Ok(EvalMethod::WithExpression(body.project()?));
653 }
654
655 Ok(
656 if body.when_then_expr.len() == 1
657 && is_cheap_and_infallible(&(body.when_then_expr[0].1))
658 && body.else_expr.is_none()
659 {
660 EvalMethod::InfallibleExprOrNull
661 } else if body.when_then_expr.len() == 1
662 && body.when_then_expr[0].1.as_any().is::<Literal>()
663 && body.else_expr.is_some()
664 && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
665 {
666 EvalMethod::ScalarOrScalar
667 } else if body.when_then_expr.len() == 1 {
668 EvalMethod::ExpressionOrExpression(body.project()?)
669 } else {
670 EvalMethod::NoExpression(body.project()?)
671 },
672 )
673 }
674
675 pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
677 self.body.expr.as_ref()
678 }
679
680 pub fn when_then_expr(&self) -> &[WhenThen] {
682 &self.body.when_then_expr
683 }
684
685 pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
687 self.body.else_expr.as_ref()
688 }
689}
690
691impl CaseBody {
692 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
693 let mut data_type = DataType::Null;
696 for i in 0..self.when_then_expr.len() {
697 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
698 if !data_type.equals_datatype(&DataType::Null) {
699 break;
700 }
701 }
702 if data_type.equals_datatype(&DataType::Null)
704 && let Some(e) = &self.else_expr
705 {
706 data_type = e.data_type(input_schema)?;
707 }
708
709 Ok(data_type)
710 }
711
712 fn case_when_with_expr(
714 &self,
715 batch: &RecordBatch,
716 return_type: &DataType,
717 ) -> Result<ColumnarValue> {
718 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
719
720 let mut remainder_rows: ArrayRef =
722 Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
723 let mut remainder_batch = Cow::Borrowed(batch);
725
726 let mut base_values = self
728 .expr
729 .as_ref()
730 .unwrap()
731 .evaluate(batch)?
732 .into_array(batch.num_rows())?;
733
734 let base_null_count = base_values.logical_null_count();
739 if base_null_count > 0 {
740 let base_not_nulls = is_not_null(base_values.as_ref())?;
744 let base_all_null = base_null_count == remainder_batch.num_rows();
745
746 if let Some(e) = &self.else_expr {
749 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
750
751 if base_all_null {
752 let nulls_value = expr.evaluate(&remainder_batch)?;
754 result_builder.add_branch_result(&remainder_rows, nulls_value)?;
755 } else {
756 let nulls_filter = create_filter(¬(&base_not_nulls)?, true);
758 let nulls_batch =
759 filter_record_batch(&remainder_batch, &nulls_filter)?;
760 let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
761 let nulls_value = expr.evaluate(&nulls_batch)?;
762 result_builder.add_branch_result(&nulls_rows, nulls_value)?;
763 }
764 }
765
766 if base_all_null {
768 return result_builder.finish();
769 }
770
771 let not_null_filter = create_filter(&base_not_nulls, true);
773 remainder_batch =
774 Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?);
775 remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?;
776 base_values = filter_array(&base_values, ¬_null_filter)?;
777 }
778
779 let base_value_is_nested = base_values.data_type().is_nested();
782
783 for i in 0..self.when_then_expr.len() {
784 let when_expr = &self.when_then_expr[i].0;
787 let when_value = match when_expr.evaluate(&remainder_batch)? {
788 ColumnarValue::Array(a) => {
789 compare_with_eq(&a, &base_values, base_value_is_nested)
790 }
791 ColumnarValue::Scalar(s) => {
792 compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
793 }
794 }?;
795
796 let when_true_count = when_value.true_count();
799
800 if when_true_count == 0 {
802 continue;
803 }
804
805 if when_true_count == remainder_batch.num_rows() {
807 let then_expression = &self.when_then_expr[i].1;
808 let then_value = then_expression.evaluate(&remainder_batch)?;
809 result_builder.add_branch_result(&remainder_rows, then_value)?;
810 return result_builder.finish();
811 }
812
813 let then_filter = create_filter(&when_value, true);
819 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
820 let then_rows = filter_array(&remainder_rows, &then_filter)?;
821
822 let then_expression = &self.when_then_expr[i].1;
823 let then_value = then_expression.evaluate(&then_batch)?;
824 result_builder.add_branch_result(&then_rows, then_value)?;
825
826 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
829 return result_builder.finish();
830 }
831
832 let next_selection = match when_value.null_count() {
834 0 => not(&when_value),
835 _ => {
836 not(&prep_null_mask_filter(&when_value))
839 }
840 }?;
841 let next_filter = create_filter(&next_selection, true);
842 remainder_batch =
843 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
844 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
845 base_values = filter_array(&base_values, &next_filter)?;
846 }
847
848 if let Some(e) = &self.else_expr {
851 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
853 let else_value = expr.evaluate(&remainder_batch)?;
854 result_builder.add_branch_result(&remainder_rows, else_value)?;
855 }
856
857 result_builder.finish()
858 }
859
860 fn case_when_no_expr(
862 &self,
863 batch: &RecordBatch,
864 return_type: &DataType,
865 ) -> Result<ColumnarValue> {
866 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
867
868 let mut remainder_rows: ArrayRef =
870 Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
871 let mut remainder_batch = Cow::Borrowed(batch);
873
874 for i in 0..self.when_then_expr.len() {
875 let when_predicate = &self.when_then_expr[i].0;
878 let when_value = when_predicate
879 .evaluate(&remainder_batch)?
880 .into_array(remainder_batch.num_rows())?;
881 let when_value = as_boolean_array(&when_value).map_err(|_| {
882 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
883 })?;
884
885 let when_true_count = when_value.true_count();
888
889 if when_true_count == 0 {
891 continue;
892 }
893
894 if when_true_count == remainder_batch.num_rows() {
896 let then_expression = &self.when_then_expr[i].1;
897 let then_value = then_expression.evaluate(&remainder_batch)?;
898 result_builder.add_branch_result(&remainder_rows, then_value)?;
899 return result_builder.finish();
900 }
901
902 let then_filter = create_filter(when_value, true);
908 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
909 let then_rows = filter_array(&remainder_rows, &then_filter)?;
910
911 let then_expression = &self.when_then_expr[i].1;
912 let then_value = then_expression.evaluate(&then_batch)?;
913 result_builder.add_branch_result(&then_rows, then_value)?;
914
915 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
918 return result_builder.finish();
919 }
920
921 let next_selection = match when_value.null_count() {
923 0 => not(when_value),
924 _ => {
925 not(&prep_null_mask_filter(when_value))
928 }
929 }?;
930 let next_filter = create_filter(&next_selection, true);
931 remainder_batch =
932 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
933 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
934 }
935
936 if let Some(e) = &self.else_expr {
939 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
941 let else_value = expr.evaluate(&remainder_batch)?;
942 result_builder.add_branch_result(&remainder_rows, else_value)?;
943 }
944
945 result_builder.finish()
946 }
947
948 fn expr_or_expr(
950 &self,
951 batch: &RecordBatch,
952 when_value: &BooleanArray,
953 ) -> Result<ColumnarValue> {
954 let when_value = match when_value.null_count() {
955 0 => Cow::Borrowed(when_value),
956 _ => {
957 Cow::Owned(prep_null_mask_filter(when_value))
959 }
960 };
961
962 let optimize_filter = batch.num_columns() > 1
963 || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
964
965 let when_filter = create_filter(&when_value, optimize_filter);
966 let then_batch = filter_record_batch(batch, &when_filter)?;
967 let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
968
969 match &self.else_expr {
970 None => {
971 let then_array = then_value.to_array(when_value.true_count())?;
972 scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array)
973 }
974 Some(else_expr) => {
975 let else_selection = not(&when_value)?;
976 let else_filter = create_filter(&else_selection, optimize_filter);
977 let else_batch = filter_record_batch(batch, &else_filter)?;
978
979 let return_type = self.data_type(&batch.schema())?;
981 let else_expr =
982 try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone())
983 .unwrap_or_else(|_| Arc::clone(else_expr));
984
985 let else_value = else_expr.evaluate(&else_batch)?;
986
987 Ok(ColumnarValue::Array(match (then_value, else_value) {
988 (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
989 merge(&when_value, &t, &e)
990 }
991 (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
992 merge(&when_value, &t.to_scalar()?, &e)
993 }
994 (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
995 merge(&when_value, &t, &e.to_scalar()?)
996 }
997 (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
998 merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
999 }
1000 }?))
1001 }
1002 }
1003 }
1004}
1005
1006impl CaseExpr {
1007 fn case_when_with_expr(
1015 &self,
1016 batch: &RecordBatch,
1017 projected: &ProjectedCaseBody,
1018 ) -> Result<ColumnarValue> {
1019 let return_type = self.data_type(&batch.schema())?;
1020 if projected.projection.len() < batch.num_columns() {
1021 let projected_batch = batch.project(&projected.projection)?;
1022 projected
1023 .body
1024 .case_when_with_expr(&projected_batch, &return_type)
1025 } else {
1026 self.body.case_when_with_expr(batch, &return_type)
1027 }
1028 }
1029
1030 fn case_when_no_expr(
1038 &self,
1039 batch: &RecordBatch,
1040 projected: &ProjectedCaseBody,
1041 ) -> Result<ColumnarValue> {
1042 let return_type = self.data_type(&batch.schema())?;
1043 if projected.projection.len() < batch.num_columns() {
1044 let projected_batch = batch.project(&projected.projection)?;
1045 projected
1046 .body
1047 .case_when_no_expr(&projected_batch, &return_type)
1048 } else {
1049 self.body.case_when_no_expr(batch, &return_type)
1050 }
1051 }
1052
1053 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1063 let when_expr = &self.body.when_then_expr[0].0;
1064 let then_expr = &self.body.when_then_expr[0].1;
1065
1066 match when_expr.evaluate(batch)? {
1067 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1069 then_expr.evaluate(batch)
1070 }
1071 ColumnarValue::Scalar(_) => {
1073 ScalarValue::try_from(self.data_type(&batch.schema())?)
1075 .map(ColumnarValue::Scalar)
1076 }
1077 ColumnarValue::Array(bit_mask) => {
1079 let bit_mask = bit_mask
1080 .as_any()
1081 .downcast_ref::<BooleanArray>()
1082 .expect("predicate should evaluate to a boolean array");
1083 let bit_mask = match bit_mask.null_count() {
1085 0 => not(bit_mask)?,
1086 _ => not(&prep_null_mask_filter(bit_mask))?,
1087 };
1088 match then_expr.evaluate(batch)? {
1089 ColumnarValue::Array(array) => {
1090 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1091 }
1092 ColumnarValue::Scalar(_) => {
1093 internal_err!("expression did not evaluate to an array")
1094 }
1095 }
1096 }
1097 }
1098 }
1099
1100 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1101 let return_type = self.data_type(&batch.schema())?;
1102
1103 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1105 let when_value = when_value.into_array(batch.num_rows())?;
1106 let when_value = as_boolean_array(&when_value).map_err(|_| {
1107 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1108 })?;
1109
1110 let when_value = match when_value.null_count() {
1112 0 => Cow::Borrowed(when_value),
1113 _ => Cow::Owned(prep_null_mask_filter(when_value)),
1114 };
1115
1116 let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1118 let then_value = Scalar::new(then_value.into_array(1)?);
1119
1120 let Some(e) = &self.body.else_expr else {
1121 return internal_err!("expression did not evaluate to an array");
1122 };
1123 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1125 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1126 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1127 }
1128
1129 fn expr_or_expr(
1130 &self,
1131 batch: &RecordBatch,
1132 projected: &ProjectedCaseBody,
1133 ) -> Result<ColumnarValue> {
1134 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1136 let when_value = when_value.into_array(1)?;
1140 let when_value = as_boolean_array(&when_value).map_err(|e| {
1141 DataFusionError::Context(
1142 "WHEN expression did not return a BooleanArray".to_string(),
1143 Box::new(e),
1144 )
1145 })?;
1146
1147 let true_count = when_value.true_count();
1148 if true_count == when_value.len() {
1149 self.body.when_then_expr[0].1.evaluate(batch)
1151 } else if true_count == 0 {
1152 match &self.body.else_expr {
1154 Some(else_expr) => else_expr.evaluate(batch),
1155 None => {
1156 let return_type = self.data_type(&batch.schema())?;
1157 Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
1158 &return_type,
1159 )?))
1160 }
1161 }
1162 } else if projected.projection.len() < batch.num_columns() {
1163 let projected_batch = batch.project(&projected.projection)?;
1166 projected.body.expr_or_expr(&projected_batch, when_value)
1167 } else {
1168 self.body.expr_or_expr(batch, when_value)
1170 }
1171 }
1172
1173 fn with_lookup_table(
1174 &self,
1175 batch: &RecordBatch,
1176 lookup_table: &LiteralLookupTable,
1177 ) -> Result<ColumnarValue> {
1178 let expr = self.body.expr.as_ref().unwrap();
1179 let evaluated_expression = expr.evaluate(batch)?;
1180
1181 let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_));
1182 let evaluated_expression = evaluated_expression.to_array(1)?;
1183
1184 let values = lookup_table.map_keys_to_values(&evaluated_expression)?;
1185
1186 let result = if is_scalar {
1187 ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?)
1188 } else {
1189 ColumnarValue::Array(values)
1190 };
1191
1192 Ok(result)
1193 }
1194}
1195
1196impl PhysicalExpr for CaseExpr {
1197 fn as_any(&self) -> &dyn Any {
1199 self
1200 }
1201
1202 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1203 self.body.data_type(input_schema)
1204 }
1205
1206 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1207 let nullable_then = self
1208 .body
1209 .when_then_expr
1210 .iter()
1211 .filter_map(|(w, t)| {
1212 let is_nullable = match t.nullable(input_schema) {
1213 Err(e) => return Some(Err(e)),
1215 Ok(n) => n,
1216 };
1217
1218 if !is_nullable {
1221 return None;
1222 }
1223
1224 if self.body.expr.is_some() {
1226 return Some(Ok(()));
1227 }
1228
1229 let with_null = match replace_with_null(w, t.as_ref(), input_schema) {
1235 Err(e) => return Some(Err(e)),
1236 Ok(e) => e,
1237 };
1238
1239 let predicate_result = match evaluate_predicate(&with_null) {
1241 Err(e) => return Some(Err(e)),
1242 Ok(b) => b,
1243 };
1244
1245 match predicate_result {
1246 None | Some(true) => Some(Ok(())),
1248 Some(false) => None,
1251 }
1252 })
1253 .next();
1254
1255 if let Some(nullable_then) = nullable_then {
1256 nullable_then.map(|_| true)
1260 } else if let Some(e) = &self.body.else_expr {
1261 e.nullable(input_schema)
1264 } else {
1265 Ok(true)
1268 }
1269 }
1270
1271 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1272 match &self.eval_method {
1273 EvalMethod::WithExpression(p) => {
1274 self.case_when_with_expr(batch, p)
1277 }
1278 EvalMethod::NoExpression(p) => {
1279 self.case_when_no_expr(batch, p)
1282 }
1283 EvalMethod::InfallibleExprOrNull => {
1284 self.case_column_or_null(batch)
1286 }
1287 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1288 EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1289 EvalMethod::WithExprScalarLookupTable(lookup_table) => {
1290 self.with_lookup_table(batch, lookup_table)
1291 }
1292 }
1293 }
1294
1295 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1296 let mut children = vec![];
1297 if let Some(expr) = &self.body.expr {
1298 children.push(expr)
1299 }
1300 self.body.when_then_expr.iter().for_each(|(cond, value)| {
1301 children.push(cond);
1302 children.push(value);
1303 });
1304
1305 if let Some(else_expr) = &self.body.else_expr {
1306 children.push(else_expr)
1307 }
1308 children
1309 }
1310
1311 fn with_new_children(
1313 self: Arc<Self>,
1314 children: Vec<Arc<dyn PhysicalExpr>>,
1315 ) -> Result<Arc<dyn PhysicalExpr>> {
1316 if children.len() != self.children().len() {
1317 internal_err!("CaseExpr: Wrong number of children")
1318 } else {
1319 let (expr, when_then_expr, else_expr) =
1320 match (self.expr().is_some(), self.body.else_expr.is_some()) {
1321 (true, true) => (
1322 Some(&children[0]),
1323 &children[1..children.len() - 1],
1324 Some(&children[children.len() - 1]),
1325 ),
1326 (true, false) => {
1327 (Some(&children[0]), &children[1..children.len()], None)
1328 }
1329 (false, true) => (
1330 None,
1331 &children[0..children.len() - 1],
1332 Some(&children[children.len() - 1]),
1333 ),
1334 (false, false) => (None, &children[0..children.len()], None),
1335 };
1336 Ok(Arc::new(CaseExpr::try_new(
1337 expr.cloned(),
1338 when_then_expr.iter().cloned().tuples().collect(),
1339 else_expr.cloned(),
1340 )?))
1341 }
1342 }
1343
1344 fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1345 write!(f, "CASE ")?;
1346 if let Some(e) = &self.body.expr {
1347 e.fmt_sql(f)?;
1348 write!(f, " ")?;
1349 }
1350
1351 for (w, t) in &self.body.when_then_expr {
1352 write!(f, "WHEN ")?;
1353 w.fmt_sql(f)?;
1354 write!(f, " THEN ")?;
1355 t.fmt_sql(f)?;
1356 write!(f, " ")?;
1357 }
1358
1359 if let Some(e) = &self.body.else_expr {
1360 write!(f, "ELSE ")?;
1361 e.fmt_sql(f)?;
1362 write!(f, " ")?;
1363 }
1364 write!(f, "END")
1365 }
1366}
1367
1368fn evaluate_predicate(predicate: &Arc<dyn PhysicalExpr>) -> Result<Option<bool>> {
1374 let batch = RecordBatch::try_new_with_options(
1376 Arc::new(Schema::empty()),
1377 vec![],
1378 &RecordBatchOptions::new().with_row_count(Some(1)),
1379 )?;
1380
1381 let result = match predicate.evaluate(&batch) {
1383 Err(_) => None,
1385 Ok(ColumnarValue::Array(array)) => Some(
1386 ScalarValue::try_from_array(array.as_ref(), 0)?
1387 .cast_to(&DataType::Boolean)?,
1388 ),
1389 Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?),
1390 };
1391 Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true)))))
1392}
1393
1394fn replace_with_null(
1395 expr: &Arc<dyn PhysicalExpr>,
1396 expr_to_replace: &dyn PhysicalExpr,
1397 input_schema: &Schema,
1398) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
1399 let with_null = Arc::clone(expr)
1400 .transform_down(|e| {
1401 if e.as_ref().dyn_eq(expr_to_replace) {
1402 let data_type = e.data_type(input_schema)?;
1403 let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
1404 Ok(Transformed::yes(null_literal))
1405 } else {
1406 Ok(Transformed::no(e))
1407 }
1408 })?
1409 .data;
1410 Ok(with_null)
1411}
1412
1413pub fn case(
1415 expr: Option<Arc<dyn PhysicalExpr>>,
1416 when_thens: Vec<WhenThen>,
1417 else_expr: Option<Arc<dyn PhysicalExpr>>,
1418) -> Result<Arc<dyn PhysicalExpr>> {
1419 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1420}
1421
1422#[cfg(test)]
1423mod tests {
1424 use super::*;
1425
1426 use crate::expressions;
1427 use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null, lit};
1428 use arrow::buffer::Buffer;
1429 use arrow::datatypes::DataType::Float64;
1430 use arrow::datatypes::Field;
1431 use datafusion_common::cast::{as_float64_array, as_int32_array};
1432 use datafusion_common::plan_err;
1433 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1434 use datafusion_expr::type_coercion::binary::comparison_coercion;
1435 use datafusion_expr_common::operator::Operator;
1436 use datafusion_physical_expr_common::physical_expr::fmt_sql;
1437 use half::f16;
1438
1439 #[test]
1440 fn case_with_expr() -> Result<()> {
1441 let batch = case_test_batch()?;
1442 let schema = batch.schema();
1443
1444 let when1 = lit("foo");
1446 let then1 = lit(123i32);
1447 let when2 = lit("bar");
1448 let then2 = lit(456i32);
1449
1450 let expr = generate_case_when_with_type_coercion(
1451 Some(col("a", &schema)?),
1452 vec![(when1, then1), (when2, then2)],
1453 None,
1454 schema.as_ref(),
1455 )?;
1456 let result = expr
1457 .evaluate(&batch)?
1458 .into_array(batch.num_rows())
1459 .expect("Failed to convert to array");
1460 let result = as_int32_array(&result)?;
1461
1462 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1463
1464 assert_eq!(expected, result);
1465
1466 Ok(())
1467 }
1468
1469 #[test]
1470 fn case_with_expr_dictionary() -> Result<()> {
1471 let schema = Schema::new(vec![Field::new(
1472 "a",
1473 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1474 true,
1475 )]);
1476 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1477 let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1478 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1479 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1480
1481 let schema = batch.schema();
1482
1483 let when1 = lit("foo");
1485 let then1 = lit(123i32);
1486 let when2 = lit("bar");
1487 let then2 = lit(456i32);
1488
1489 let expr = generate_case_when_with_type_coercion(
1490 Some(col("a", &schema)?),
1491 vec![(when1, then1), (when2, then2)],
1492 None,
1493 schema.as_ref(),
1494 )?;
1495 let result = expr
1496 .evaluate(&batch)?
1497 .into_array(batch.num_rows())
1498 .expect("Failed to convert to array");
1499 let result = as_int32_array(&result)?;
1500
1501 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1502
1503 assert_eq!(expected, result);
1504
1505 Ok(())
1506 }
1507
1508 #[test]
1510 fn case_with_expr_primitive_dictionary() -> Result<()> {
1511 let schema = Schema::new(vec![Field::new(
1512 "a",
1513 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64)),
1514 true,
1515 )]);
1516 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1517 let values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]);
1518 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1519 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1520
1521 let schema = batch.schema();
1522
1523 let when1 = lit(10_u64);
1525 let then1 = lit(123_i32);
1526 let when2 = lit(30_u64);
1527 let then2 = lit(456_i32);
1528
1529 let expr = generate_case_when_with_type_coercion(
1530 Some(col("a", &schema)?),
1531 vec![(when1, then1), (when2, then2)],
1532 None,
1533 schema.as_ref(),
1534 )?;
1535 let result = expr
1536 .evaluate(&batch)?
1537 .into_array(batch.num_rows())
1538 .expect("Failed to convert to array");
1539 let result = as_int32_array(&result)?;
1540
1541 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1542
1543 assert_eq!(expected, result);
1544
1545 Ok(())
1546 }
1547
1548 #[test]
1550 fn case_with_expr_boolean_dictionary() -> Result<()> {
1551 let schema = Schema::new(vec![Field::new(
1552 "a",
1553 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean)),
1554 true,
1555 )]);
1556 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1557 let values = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
1558 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1559 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1560
1561 let schema = batch.schema();
1562
1563 let when1 = lit(true);
1565 let then1 = lit(123i32);
1566 let when2 = lit(false);
1567 let then2 = lit(456i32);
1568
1569 let expr = generate_case_when_with_type_coercion(
1570 Some(col("a", &schema)?),
1571 vec![(when1, then1), (when2, then2)],
1572 None,
1573 schema.as_ref(),
1574 )?;
1575 let result = expr
1576 .evaluate(&batch)?
1577 .into_array(batch.num_rows())
1578 .expect("Failed to convert to array");
1579 let result = as_int32_array(&result)?;
1580
1581 let expected = &Int32Array::from(vec![Some(123), Some(456), None, Some(123)]);
1582
1583 assert_eq!(expected, result);
1584
1585 Ok(())
1586 }
1587
1588 #[test]
1589 fn case_with_expr_all_null_dictionary() -> Result<()> {
1590 let schema = Schema::new(vec![Field::new(
1591 "a",
1592 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1593 true,
1594 )]);
1595 let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1596 let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1597 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1598 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1599
1600 let schema = batch.schema();
1601
1602 let when1 = lit("foo");
1604 let then1 = lit(123i32);
1605 let when2 = lit("bar");
1606 let then2 = lit(456i32);
1607
1608 let expr = generate_case_when_with_type_coercion(
1609 Some(col("a", &schema)?),
1610 vec![(when1, then1), (when2, then2)],
1611 None,
1612 schema.as_ref(),
1613 )?;
1614 let result = expr
1615 .evaluate(&batch)?
1616 .into_array(batch.num_rows())
1617 .expect("Failed to convert to array");
1618 let result = as_int32_array(&result)?;
1619
1620 let expected = &Int32Array::from(vec![None, None, None, None]);
1621
1622 assert_eq!(expected, result);
1623
1624 Ok(())
1625 }
1626
1627 #[test]
1628 fn case_with_expr_else() -> Result<()> {
1629 let batch = case_test_batch()?;
1630 let schema = batch.schema();
1631
1632 let when1 = lit("foo");
1634 let then1 = lit(123i32);
1635 let when2 = lit("bar");
1636 let then2 = lit(456i32);
1637 let else_value = lit(999i32);
1638
1639 let expr = generate_case_when_with_type_coercion(
1640 Some(col("a", &schema)?),
1641 vec![(when1, then1), (when2, then2)],
1642 Some(else_value),
1643 schema.as_ref(),
1644 )?;
1645 let result = expr
1646 .evaluate(&batch)?
1647 .into_array(batch.num_rows())
1648 .expect("Failed to convert to array");
1649 let result = as_int32_array(&result)?;
1650
1651 let expected =
1652 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1653
1654 assert_eq!(expected, result);
1655
1656 Ok(())
1657 }
1658
1659 #[test]
1660 fn case_with_expr_divide_by_zero() -> Result<()> {
1661 let batch = case_test_batch1()?;
1662 let schema = batch.schema();
1663
1664 let when1 = lit(0i32);
1666 let then1 = lit(ScalarValue::Float64(None));
1667 let else_value = binary(
1668 lit(25.0f64),
1669 Operator::Divide,
1670 cast(col("a", &schema)?, &batch.schema(), Float64)?,
1671 &batch.schema(),
1672 )?;
1673
1674 let expr = generate_case_when_with_type_coercion(
1675 Some(col("a", &schema)?),
1676 vec![(when1, then1)],
1677 Some(else_value),
1678 schema.as_ref(),
1679 )?;
1680 let result = expr
1681 .evaluate(&batch)?
1682 .into_array(batch.num_rows())
1683 .expect("Failed to convert to array");
1684 let result =
1685 as_float64_array(&result).expect("failed to downcast to Float64Array");
1686
1687 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1688
1689 assert_eq!(expected, result);
1690
1691 Ok(())
1692 }
1693
1694 #[test]
1695 fn case_without_expr() -> Result<()> {
1696 let batch = case_test_batch()?;
1697 let schema = batch.schema();
1698
1699 let when1 = binary(
1701 col("a", &schema)?,
1702 Operator::Eq,
1703 lit("foo"),
1704 &batch.schema(),
1705 )?;
1706 let then1 = lit(123i32);
1707 let when2 = binary(
1708 col("a", &schema)?,
1709 Operator::Eq,
1710 lit("bar"),
1711 &batch.schema(),
1712 )?;
1713 let then2 = lit(456i32);
1714
1715 let expr = generate_case_when_with_type_coercion(
1716 None,
1717 vec![(when1, then1), (when2, then2)],
1718 None,
1719 schema.as_ref(),
1720 )?;
1721 let result = expr
1722 .evaluate(&batch)?
1723 .into_array(batch.num_rows())
1724 .expect("Failed to convert to array");
1725 let result = as_int32_array(&result)?;
1726
1727 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1728
1729 assert_eq!(expected, result);
1730
1731 Ok(())
1732 }
1733
1734 #[test]
1735 fn case_with_expr_when_null() -> Result<()> {
1736 let batch = case_test_batch()?;
1737 let schema = batch.schema();
1738
1739 let when1 = lit(ScalarValue::Utf8(None));
1741 let then1 = lit(0i32);
1742 let when2 = col("a", &schema)?;
1743 let then2 = lit(123i32);
1744 let else_value = lit(999i32);
1745
1746 let expr = generate_case_when_with_type_coercion(
1747 Some(col("a", &schema)?),
1748 vec![(when1, then1), (when2, then2)],
1749 Some(else_value),
1750 schema.as_ref(),
1751 )?;
1752 let result = expr
1753 .evaluate(&batch)?
1754 .into_array(batch.num_rows())
1755 .expect("Failed to convert to array");
1756 let result = as_int32_array(&result)?;
1757
1758 let expected =
1759 &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
1760
1761 assert_eq!(expected, result);
1762
1763 Ok(())
1764 }
1765
1766 #[test]
1767 fn case_without_expr_divide_by_zero() -> Result<()> {
1768 let batch = case_test_batch1()?;
1769 let schema = batch.schema();
1770
1771 let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1773 let then1 = binary(
1774 lit(25.0f64),
1775 Operator::Divide,
1776 cast(col("a", &schema)?, &batch.schema(), Float64)?,
1777 &batch.schema(),
1778 )?;
1779 let x = lit(ScalarValue::Float64(None));
1780
1781 let expr = generate_case_when_with_type_coercion(
1782 None,
1783 vec![(when1, then1)],
1784 Some(x),
1785 schema.as_ref(),
1786 )?;
1787 let result = expr
1788 .evaluate(&batch)?
1789 .into_array(batch.num_rows())
1790 .expect("Failed to convert to array");
1791 let result =
1792 as_float64_array(&result).expect("failed to downcast to Float64Array");
1793
1794 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1795
1796 assert_eq!(expected, result);
1797
1798 Ok(())
1799 }
1800
1801 fn case_test_batch1() -> Result<RecordBatch> {
1802 let schema = Schema::new(vec![
1803 Field::new("a", DataType::Int32, true),
1804 Field::new("b", DataType::Int32, true),
1805 Field::new("c", DataType::Int32, true),
1806 ]);
1807 let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1808 let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1809 let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1810 let batch = RecordBatch::try_new(
1811 Arc::new(schema),
1812 vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1813 )?;
1814 Ok(batch)
1815 }
1816
1817 #[test]
1818 fn case_without_expr_else() -> Result<()> {
1819 let batch = case_test_batch()?;
1820 let schema = batch.schema();
1821
1822 let when1 = binary(
1824 col("a", &schema)?,
1825 Operator::Eq,
1826 lit("foo"),
1827 &batch.schema(),
1828 )?;
1829 let then1 = lit(123i32);
1830 let when2 = binary(
1831 col("a", &schema)?,
1832 Operator::Eq,
1833 lit("bar"),
1834 &batch.schema(),
1835 )?;
1836 let then2 = lit(456i32);
1837 let else_value = lit(999i32);
1838
1839 let expr = generate_case_when_with_type_coercion(
1840 None,
1841 vec![(when1, then1), (when2, then2)],
1842 Some(else_value),
1843 schema.as_ref(),
1844 )?;
1845 let result = expr
1846 .evaluate(&batch)?
1847 .into_array(batch.num_rows())
1848 .expect("Failed to convert to array");
1849 let result = as_int32_array(&result)?;
1850
1851 let expected =
1852 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1853
1854 assert_eq!(expected, result);
1855
1856 Ok(())
1857 }
1858
1859 #[test]
1860 fn case_with_type_cast() -> Result<()> {
1861 let batch = case_test_batch()?;
1862 let schema = batch.schema();
1863
1864 let when = binary(
1866 col("a", &schema)?,
1867 Operator::Eq,
1868 lit("foo"),
1869 &batch.schema(),
1870 )?;
1871 let then = lit(123.3f64);
1872 let else_value = lit(999i32);
1873
1874 let expr = generate_case_when_with_type_coercion(
1875 None,
1876 vec![(when, then)],
1877 Some(else_value),
1878 schema.as_ref(),
1879 )?;
1880 let result = expr
1881 .evaluate(&batch)?
1882 .into_array(batch.num_rows())
1883 .expect("Failed to convert to array");
1884 let result =
1885 as_float64_array(&result).expect("failed to downcast to Float64Array");
1886
1887 let expected =
1888 &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
1889
1890 assert_eq!(expected, result);
1891
1892 Ok(())
1893 }
1894
1895 #[test]
1896 fn case_with_matches_and_nulls() -> Result<()> {
1897 let batch = case_test_batch_nulls()?;
1898 let schema = batch.schema();
1899
1900 let when = binary(
1902 col("load4", &schema)?,
1903 Operator::Eq,
1904 lit(1.77f64),
1905 &batch.schema(),
1906 )?;
1907 let then = col("load4", &schema)?;
1908
1909 let expr = generate_case_when_with_type_coercion(
1910 None,
1911 vec![(when, then)],
1912 None,
1913 schema.as_ref(),
1914 )?;
1915 let result = expr
1916 .evaluate(&batch)?
1917 .into_array(batch.num_rows())
1918 .expect("Failed to convert to array");
1919 let result =
1920 as_float64_array(&result).expect("failed to downcast to Float64Array");
1921
1922 let expected =
1923 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1924
1925 assert_eq!(expected, result);
1926
1927 Ok(())
1928 }
1929
1930 #[test]
1931 fn case_with_scalar_predicate() -> Result<()> {
1932 let batch = case_test_batch_nulls()?;
1933 let schema = batch.schema();
1934
1935 let when = lit(true);
1937 let then = col("load4", &schema)?;
1938 let expr = generate_case_when_with_type_coercion(
1939 None,
1940 vec![(when, then)],
1941 None,
1942 schema.as_ref(),
1943 )?;
1944
1945 let result = expr
1947 .evaluate(&batch)?
1948 .into_array(batch.num_rows())
1949 .expect("Failed to convert to array");
1950 let result =
1951 as_float64_array(&result).expect("failed to downcast to Float64Array");
1952 let expected = &Float64Array::from(vec![
1953 Some(1.77),
1954 None,
1955 None,
1956 Some(1.78),
1957 None,
1958 Some(1.77),
1959 ]);
1960 assert_eq!(expected, result);
1961
1962 let expected = Float64Array::from(vec![Some(1.1)]);
1964 let batch =
1965 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
1966 let result = expr
1967 .evaluate(&batch)?
1968 .into_array(batch.num_rows())
1969 .expect("Failed to convert to array");
1970 let result =
1971 as_float64_array(&result).expect("failed to downcast to Float64Array");
1972 assert_eq!(&expected, result);
1973
1974 Ok(())
1975 }
1976
1977 #[test]
1978 fn case_expr_matches_and_nulls() -> Result<()> {
1979 let batch = case_test_batch_nulls()?;
1980 let schema = batch.schema();
1981
1982 let expr = col("load4", &schema)?;
1984 let when = lit(1.77f64);
1985 let then = col("load4", &schema)?;
1986
1987 let expr = generate_case_when_with_type_coercion(
1988 Some(expr),
1989 vec![(when, then)],
1990 None,
1991 schema.as_ref(),
1992 )?;
1993 let result = expr
1994 .evaluate(&batch)?
1995 .into_array(batch.num_rows())
1996 .expect("Failed to convert to array");
1997 let result =
1998 as_float64_array(&result).expect("failed to downcast to Float64Array");
1999
2000 let expected =
2001 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
2002
2003 assert_eq!(expected, result);
2004
2005 Ok(())
2006 }
2007
2008 #[test]
2009 fn test_when_null_and_some_cond_else_null() -> Result<()> {
2010 let batch = case_test_batch()?;
2011 let schema = batch.schema();
2012
2013 let when = binary(
2014 Arc::new(Literal::new(ScalarValue::Boolean(None))),
2015 Operator::And,
2016 binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
2017 &schema,
2018 )?;
2019 let then = col("a", &schema)?;
2020
2021 let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
2023 let result = expr
2024 .evaluate(&batch)?
2025 .into_array(batch.num_rows())
2026 .expect("Failed to convert to array");
2027 let result = as_string_array(&result);
2028
2029 assert_eq!(result.logical_null_count(), batch.num_rows());
2031 Ok(())
2032 }
2033
2034 fn case_test_batch() -> Result<RecordBatch> {
2035 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2036 let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
2037 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
2038 Ok(batch)
2039 }
2040
2041 fn case_test_batch_nulls() -> Result<RecordBatch> {
2044 let load4: Float64Array = vec![
2045 Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
2052 .into_iter()
2053 .collect();
2054
2055 let null_buffer = Buffer::from([0b00101001u8]);
2056 let load4 = load4
2057 .into_data()
2058 .into_builder()
2059 .null_bit_buffer(Some(null_buffer))
2060 .build()
2061 .unwrap();
2062 let load4: Float64Array = load4.into();
2063
2064 let batch =
2065 RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
2066 Ok(batch)
2067 }
2068
2069 #[test]
2070 fn case_test_incompatible() -> Result<()> {
2071 let batch = case_test_batch()?;
2074 let schema = batch.schema();
2075
2076 let when1 = binary(
2078 col("a", &schema)?,
2079 Operator::Eq,
2080 lit("foo"),
2081 &batch.schema(),
2082 )?;
2083 let then1 = lit(123i32);
2084 let when2 = binary(
2085 col("a", &schema)?,
2086 Operator::Eq,
2087 lit("bar"),
2088 &batch.schema(),
2089 )?;
2090 let then2 = lit(true);
2091
2092 let expr = generate_case_when_with_type_coercion(
2093 None,
2094 vec![(when1, then1), (when2, then2)],
2095 None,
2096 schema.as_ref(),
2097 );
2098 assert!(expr.is_err());
2099
2100 let when1 = binary(
2105 col("a", &schema)?,
2106 Operator::Eq,
2107 lit("foo"),
2108 &batch.schema(),
2109 )?;
2110 let then1 = lit(123i32);
2111 let when2 = binary(
2112 col("a", &schema)?,
2113 Operator::Eq,
2114 lit("bar"),
2115 &batch.schema(),
2116 )?;
2117 let then2 = lit(456i64);
2118 let else_expr = lit(1.23f64);
2119
2120 let expr = generate_case_when_with_type_coercion(
2121 None,
2122 vec![(when1, then1), (when2, then2)],
2123 Some(else_expr),
2124 schema.as_ref(),
2125 );
2126 assert!(expr.is_ok());
2127 let result_type = expr.unwrap().data_type(schema.as_ref())?;
2128 assert_eq!(Float64, result_type);
2129 Ok(())
2130 }
2131
2132 #[test]
2133 fn case_eq() -> Result<()> {
2134 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2135
2136 let when1 = lit("foo");
2137 let then1 = lit(123i32);
2138 let when2 = lit("bar");
2139 let then2 = lit(456i32);
2140 let else_value = lit(999i32);
2141
2142 let expr1 = generate_case_when_with_type_coercion(
2143 Some(col("a", &schema)?),
2144 vec![
2145 (Arc::clone(&when1), Arc::clone(&then1)),
2146 (Arc::clone(&when2), Arc::clone(&then2)),
2147 ],
2148 Some(Arc::clone(&else_value)),
2149 &schema,
2150 )?;
2151
2152 let expr2 = generate_case_when_with_type_coercion(
2153 Some(col("a", &schema)?),
2154 vec![
2155 (Arc::clone(&when1), Arc::clone(&then1)),
2156 (Arc::clone(&when2), Arc::clone(&then2)),
2157 ],
2158 Some(Arc::clone(&else_value)),
2159 &schema,
2160 )?;
2161
2162 let expr3 = generate_case_when_with_type_coercion(
2163 Some(col("a", &schema)?),
2164 vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
2165 None,
2166 &schema,
2167 )?;
2168
2169 let expr4 = generate_case_when_with_type_coercion(
2170 Some(col("a", &schema)?),
2171 vec![(when1, then1)],
2172 Some(else_value),
2173 &schema,
2174 )?;
2175
2176 assert!(expr1.eq(&expr2));
2177 assert!(expr2.eq(&expr1));
2178
2179 assert!(expr2.ne(&expr3));
2180 assert!(expr3.ne(&expr2));
2181
2182 assert!(expr1.ne(&expr4));
2183 assert!(expr4.ne(&expr1));
2184
2185 Ok(())
2186 }
2187
2188 #[test]
2189 fn case_transform() -> Result<()> {
2190 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2191
2192 let when1 = lit("foo");
2193 let then1 = lit(123i32);
2194 let when2 = lit("bar");
2195 let then2 = lit(456i32);
2196 let else_value = lit(999i32);
2197
2198 let expr = generate_case_when_with_type_coercion(
2199 Some(col("a", &schema)?),
2200 vec![
2201 (Arc::clone(&when1), Arc::clone(&then1)),
2202 (Arc::clone(&when2), Arc::clone(&then2)),
2203 ],
2204 Some(Arc::clone(&else_value)),
2205 &schema,
2206 )?;
2207
2208 let expr2 = Arc::clone(&expr)
2209 .transform(|e| {
2210 let transformed = match e.as_any().downcast_ref::<Literal>() {
2211 Some(lit_value) => match lit_value.value() {
2212 ScalarValue::Utf8(Some(str_value)) => {
2213 Some(lit(str_value.to_uppercase()))
2214 }
2215 _ => None,
2216 },
2217 _ => None,
2218 };
2219 Ok(if let Some(transformed) = transformed {
2220 Transformed::yes(transformed)
2221 } else {
2222 Transformed::no(e)
2223 })
2224 })
2225 .data()
2226 .unwrap();
2227
2228 let expr3 = Arc::clone(&expr)
2229 .transform_down(|e| {
2230 let transformed = match e.as_any().downcast_ref::<Literal>() {
2231 Some(lit_value) => match lit_value.value() {
2232 ScalarValue::Utf8(Some(str_value)) => {
2233 Some(lit(str_value.to_uppercase()))
2234 }
2235 _ => None,
2236 },
2237 _ => None,
2238 };
2239 Ok(if let Some(transformed) = transformed {
2240 Transformed::yes(transformed)
2241 } else {
2242 Transformed::no(e)
2243 })
2244 })
2245 .data()
2246 .unwrap();
2247
2248 assert!(expr.ne(&expr2));
2249 assert!(expr2.eq(&expr3));
2250
2251 Ok(())
2252 }
2253
2254 #[test]
2255 fn test_column_or_null_specialization() -> Result<()> {
2256 let mut c1 = Int32Builder::new();
2258 let mut c2 = StringBuilder::new();
2259 for i in 0..1000 {
2260 c1.append_value(i);
2261 if i % 7 == 0 {
2262 c2.append_null();
2263 } else {
2264 c2.append_value(format!("string {i}"));
2265 }
2266 }
2267 let c1 = Arc::new(c1.finish());
2268 let c2 = Arc::new(c2.finish());
2269 let schema = Schema::new(vec![
2270 Field::new("c1", DataType::Int32, true),
2271 Field::new("c2", DataType::Utf8, true),
2272 ]);
2273 let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
2274
2275 let predicate = Arc::new(BinaryExpr::new(
2277 make_col("c1", 0),
2278 Operator::LtEq,
2279 make_lit_i32(250),
2280 ));
2281 let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
2282 assert_eq!(expr.eval_method, EvalMethod::InfallibleExprOrNull);
2283 match expr.evaluate(&batch)? {
2284 ColumnarValue::Array(array) => {
2285 assert_eq!(1000, array.len());
2286 assert_eq!(785, array.null_count());
2287 }
2288 _ => unreachable!(),
2289 }
2290 Ok(())
2291 }
2292
2293 #[test]
2294 fn test_expr_or_expr_specialization() -> Result<()> {
2295 let batch = case_test_batch1()?;
2296 let schema = batch.schema();
2297 let when = binary(
2298 col("a", &schema)?,
2299 Operator::LtEq,
2300 lit(2i32),
2301 &batch.schema(),
2302 )?;
2303 let then = col("b", &schema)?;
2304 let else_expr = col("c", &schema)?;
2305 let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
2306 assert!(matches!(
2307 expr.eval_method,
2308 EvalMethod::ExpressionOrExpression(_)
2309 ));
2310 let result = expr
2311 .evaluate(&batch)?
2312 .into_array(batch.num_rows())
2313 .expect("Failed to convert to array");
2314 let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
2315
2316 let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
2317
2318 assert_eq!(expected, result);
2319 Ok(())
2320 }
2321
2322 fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2323 Arc::new(Column::new(name, index))
2324 }
2325
2326 fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2327 Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2328 }
2329
2330 fn generate_case_when_with_type_coercion(
2331 expr: Option<Arc<dyn PhysicalExpr>>,
2332 when_thens: Vec<WhenThen>,
2333 else_expr: Option<Arc<dyn PhysicalExpr>>,
2334 input_schema: &Schema,
2335 ) -> Result<Arc<dyn PhysicalExpr>> {
2336 let coerce_type =
2337 get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2338 let (when_thens, else_expr) = match coerce_type {
2339 None => plan_err!(
2340 "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2341 ),
2342 Some(data_type) => {
2343 let left = when_thens
2345 .into_iter()
2346 .map(|(when, then)| {
2347 let then = try_cast(then, input_schema, data_type.clone())?;
2348 Ok((when, then))
2349 })
2350 .collect::<Result<Vec<_>>>()?;
2351 let right = match else_expr {
2352 None => None,
2353 Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2354 };
2355
2356 Ok((left, right))
2357 }
2358 }?;
2359 case(expr, when_thens, else_expr)
2360 }
2361
2362 fn get_case_common_type(
2363 when_thens: &[WhenThen],
2364 else_expr: Option<Arc<dyn PhysicalExpr>>,
2365 input_schema: &Schema,
2366 ) -> Option<DataType> {
2367 let thens_type = when_thens
2368 .iter()
2369 .map(|when_then| {
2370 let data_type = &when_then.1.data_type(input_schema).unwrap();
2371 data_type.clone()
2372 })
2373 .collect::<Vec<_>>();
2374 let else_type = match else_expr {
2375 None => {
2376 thens_type[0].clone()
2378 }
2379 Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2380 };
2381 thens_type
2382 .iter()
2383 .try_fold(else_type, |left_type, right_type| {
2384 comparison_coercion(&left_type, right_type)
2387 })
2388 }
2389
2390 #[test]
2391 fn test_fmt_sql() -> Result<()> {
2392 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2393
2394 let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
2396 let then = lit(123.3f64);
2397 let else_value = lit(999i32);
2398
2399 let expr = generate_case_when_with_type_coercion(
2400 None,
2401 vec![(when, then)],
2402 Some(else_value),
2403 &schema,
2404 )?;
2405
2406 let display_string = expr.to_string();
2407 assert_eq!(
2408 display_string,
2409 "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2410 );
2411
2412 let sql_string = fmt_sql(expr.as_ref()).to_string();
2413 assert_eq!(
2414 sql_string,
2415 "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2416 );
2417
2418 Ok(())
2419 }
2420
2421 fn when_then_else(
2422 when: &Arc<dyn PhysicalExpr>,
2423 then: &Arc<dyn PhysicalExpr>,
2424 els: &Arc<dyn PhysicalExpr>,
2425 ) -> Result<Arc<dyn PhysicalExpr>> {
2426 let case = CaseExpr::try_new(
2427 None,
2428 vec![(Arc::clone(when), Arc::clone(then))],
2429 Some(Arc::clone(els)),
2430 )?;
2431 Ok(Arc::new(case))
2432 }
2433
2434 #[test]
2435 fn test_case_expression_nullability_with_nullable_column() -> Result<()> {
2436 case_expression_nullability(true)
2437 }
2438
2439 #[test]
2440 fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> {
2441 case_expression_nullability(false)
2442 }
2443
2444 fn case_expression_nullability(col_is_nullable: bool) -> Result<()> {
2445 let schema =
2446 Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]);
2447
2448 let foo = col("foo", &schema)?;
2449 let foo_is_not_null = is_not_null(Arc::clone(&foo))?;
2450 let foo_is_null = expressions::is_null(Arc::clone(&foo))?;
2451 let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?;
2452 let zero = lit(0);
2453 let foo_eq_zero =
2454 binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?;
2455
2456 assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema);
2457 assert_not_nullable(when_then_else(¬_foo_is_null, &foo, &zero)?, &schema);
2458 assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema);
2459
2460 assert_not_nullable(
2461 when_then_else(
2462 &binary(
2463 Arc::clone(&foo_is_not_null),
2464 Operator::And,
2465 Arc::clone(&foo_eq_zero),
2466 &schema,
2467 )?,
2468 &foo,
2469 &zero,
2470 )?,
2471 &schema,
2472 );
2473
2474 assert_not_nullable(
2475 when_then_else(
2476 &binary(
2477 Arc::clone(&foo_eq_zero),
2478 Operator::And,
2479 Arc::clone(&foo_is_not_null),
2480 &schema,
2481 )?,
2482 &foo,
2483 &zero,
2484 )?,
2485 &schema,
2486 );
2487
2488 assert_not_nullable(
2489 when_then_else(
2490 &binary(
2491 Arc::clone(&foo_is_not_null),
2492 Operator::Or,
2493 Arc::clone(&foo_eq_zero),
2494 &schema,
2495 )?,
2496 &foo,
2497 &zero,
2498 )?,
2499 &schema,
2500 );
2501
2502 assert_not_nullable(
2503 when_then_else(
2504 &binary(
2505 Arc::clone(&foo_eq_zero),
2506 Operator::Or,
2507 Arc::clone(&foo_is_not_null),
2508 &schema,
2509 )?,
2510 &foo,
2511 &zero,
2512 )?,
2513 &schema,
2514 );
2515
2516 assert_nullability(
2517 when_then_else(
2518 &binary(
2519 Arc::clone(&foo_is_null),
2520 Operator::Or,
2521 Arc::clone(&foo_eq_zero),
2522 &schema,
2523 )?,
2524 &foo,
2525 &zero,
2526 )?,
2527 &schema,
2528 col_is_nullable,
2529 );
2530
2531 assert_nullability(
2532 when_then_else(
2533 &binary(
2534 binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?,
2535 Operator::Or,
2536 Arc::clone(&foo_is_null),
2537 &schema,
2538 )?,
2539 &foo,
2540 &zero,
2541 )?,
2542 &schema,
2543 col_is_nullable,
2544 );
2545
2546 assert_not_nullable(
2547 when_then_else(
2548 &binary(
2549 binary(
2550 binary(
2551 Arc::clone(&foo),
2552 Operator::Eq,
2553 Arc::clone(&zero),
2554 &schema,
2555 )?,
2556 Operator::And,
2557 Arc::clone(&foo_is_not_null),
2558 &schema,
2559 )?,
2560 Operator::Or,
2561 binary(
2562 binary(
2563 Arc::clone(&foo),
2564 Operator::Eq,
2565 Arc::clone(&foo),
2566 &schema,
2567 )?,
2568 Operator::And,
2569 Arc::clone(&foo_is_not_null),
2570 &schema,
2571 )?,
2572 &schema,
2573 )?,
2574 &foo,
2575 &zero,
2576 )?,
2577 &schema,
2578 );
2579
2580 Ok(())
2581 }
2582
2583 fn assert_not_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2584 assert!(!expr.nullable(schema).unwrap());
2585 }
2586
2587 fn assert_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2588 assert!(expr.nullable(schema).unwrap());
2589 }
2590
2591 fn assert_nullability(expr: Arc<dyn PhysicalExpr>, schema: &Schema, nullable: bool) {
2592 if nullable {
2593 assert_nullable(expr, schema);
2594 } else {
2595 assert_not_nullable(expr, schema);
2596 }
2597 }
2598
2599 fn test_case_when_literal_lookup(
2602 values: ArrayRef,
2603 lookup_map: &[(ScalarValue, ScalarValue)],
2604 else_value: Option<ScalarValue>,
2605 expected: ArrayRef,
2606 ) {
2607 let schema = Schema::new(vec![Field::new(
2614 "a",
2615 values.data_type().clone(),
2616 values.is_nullable(),
2617 )]);
2618 let schema = Arc::new(schema);
2619
2620 let batch = RecordBatch::try_new(schema, vec![values])
2621 .expect("failed to create RecordBatch");
2622
2623 let schema = batch.schema_ref();
2624 let case = col("a", schema).expect("failed to create col");
2625
2626 let when_then = lookup_map
2627 .iter()
2628 .map(|(when, then)| {
2629 (
2630 Arc::new(Literal::new(when.clone())) as _,
2631 Arc::new(Literal::new(then.clone())) as _,
2632 )
2633 })
2634 .collect::<Vec<WhenThen>>();
2635
2636 let else_expr = else_value.map(|else_value| {
2637 Arc::new(Literal::new(else_value)) as Arc<dyn PhysicalExpr>
2638 });
2639 let expr = CaseExpr::try_new(Some(case), when_then, else_expr)
2640 .expect("failed to create case");
2641
2642 assert!(
2644 matches!(
2645 expr.eval_method,
2646 EvalMethod::WithExprScalarLookupTable { .. }
2647 ),
2648 "we should use the expected eval method"
2649 );
2650
2651 let actual = expr
2652 .evaluate(&batch)
2653 .expect("failed to evaluate case")
2654 .into_array(batch.num_rows())
2655 .expect("Failed to convert to array");
2656
2657 assert_eq!(
2658 actual.data_type(),
2659 expected.data_type(),
2660 "Data type mismatch"
2661 );
2662
2663 assert_eq!(
2664 actual.as_ref(),
2665 expected.as_ref(),
2666 "actual (left) does not match expected (right)"
2667 );
2668 }
2669
2670 fn create_lookup<When, Then>(
2671 when_then_pairs: impl IntoIterator<Item = (When, Then)>,
2672 ) -> Vec<(ScalarValue, ScalarValue)>
2673 where
2674 ScalarValue: From<When>,
2675 ScalarValue: From<Then>,
2676 {
2677 when_then_pairs
2678 .into_iter()
2679 .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then)))
2680 .collect()
2681 }
2682
2683 fn create_input_and_expected<Input, Expected, InputFromItem, ExpectedFromItem>(
2684 input_and_expected_pairs: impl IntoIterator<Item = (InputFromItem, ExpectedFromItem)>,
2685 ) -> (Input, Expected)
2686 where
2687 Input: Array + From<Vec<InputFromItem>>,
2688 Expected: Array + From<Vec<ExpectedFromItem>>,
2689 {
2690 let (input_items, expected_items): (Vec<InputFromItem>, Vec<ExpectedFromItem>) =
2691 input_and_expected_pairs.into_iter().unzip();
2692
2693 (Input::from(input_items), Expected::from(expected_items))
2694 }
2695
2696 fn test_lookup_eval_with_and_without_else(
2697 lookup_map: &[(ScalarValue, ScalarValue)],
2698 input_values: ArrayRef,
2699 expected: StringArray,
2700 ) {
2701 test_case_when_literal_lookup(
2703 Arc::clone(&input_values),
2704 lookup_map,
2705 None,
2706 Arc::new(expected.clone()),
2707 );
2708
2709 let else_value = "___fallback___";
2711
2712 let expected_with_else = expected
2714 .iter()
2715 .map(|item| item.unwrap_or(else_value))
2716 .map(Some)
2717 .collect::<StringArray>();
2718
2719 test_case_when_literal_lookup(
2721 input_values,
2722 lookup_map,
2723 Some(ScalarValue::Utf8(Some(else_value.to_string()))),
2724 Arc::new(expected_with_else),
2725 );
2726 }
2727
2728 #[test]
2729 fn test_case_when_literal_lookup_int32_to_string() {
2730 let lookup_map = create_lookup([
2731 (Some(4), Some("four")),
2732 (Some(2), Some("two")),
2733 (Some(3), Some("three")),
2734 (Some(1), Some("one")),
2735 ]);
2736
2737 let (input_values, expected) =
2738 create_input_and_expected::<Int32Array, StringArray, _, _>([
2739 (1, Some("one")),
2740 (2, Some("two")),
2741 (3, Some("three")),
2742 (3, Some("three")),
2743 (2, Some("two")),
2744 (3, Some("three")),
2745 (5, None), (5, None), (3, Some("three")),
2748 (5, None), ]);
2750
2751 test_lookup_eval_with_and_without_else(
2752 &lookup_map,
2753 Arc::new(input_values),
2754 expected,
2755 );
2756 }
2757
2758 #[test]
2759 fn test_case_when_literal_lookup_none_case_should_never_match() {
2760 let lookup_map = create_lookup([
2761 (Some(4), Some("four")),
2762 (None, Some("none")),
2763 (Some(2), Some("two")),
2764 (Some(1), Some("one")),
2765 ]);
2766
2767 let (input_values, expected) =
2768 create_input_and_expected::<Int32Array, StringArray, _, _>([
2769 (Some(1), Some("one")),
2770 (Some(5), None), (None, None), (Some(2), Some("two")),
2773 (None, None), (None, None), (Some(2), Some("two")),
2776 (Some(5), None), ]);
2778
2779 test_lookup_eval_with_and_without_else(
2780 &lookup_map,
2781 Arc::new(input_values),
2782 expected,
2783 );
2784 }
2785
2786 #[test]
2787 fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() {
2788 let lookup_map = create_lookup([
2789 (Some(4), Some("four")),
2790 (Some(4), Some("no 4")),
2791 (Some(2), Some("two")),
2792 (Some(2), Some("no 2")),
2793 (Some(3), Some("three")),
2794 (Some(3), Some("no 3")),
2795 (Some(2), Some("no 2")),
2796 (Some(4), Some("no 4")),
2797 (Some(2), Some("no 2")),
2798 (Some(3), Some("no 3")),
2799 (Some(4), Some("no 4")),
2800 (Some(2), Some("no 2")),
2801 (Some(3), Some("no 3")),
2802 (Some(3), Some("no 3")),
2803 ]);
2804
2805 let (input_values, expected) =
2806 create_input_and_expected::<Int32Array, StringArray, _, _>([
2807 (1, None), (2, Some("two")),
2809 (3, Some("three")),
2810 (3, Some("three")),
2811 (2, Some("two")),
2812 (3, Some("three")),
2813 (5, None), (5, None), (3, Some("three")),
2816 (5, None), ]);
2818
2819 test_lookup_eval_with_and_without_else(
2820 &lookup_map,
2821 Arc::new(input_values),
2822 expected,
2823 );
2824 }
2825
2826 #[test]
2827 fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases()
2828 {
2829 let lookup_map = create_lookup([
2830 (Some(4.0), Some("four point zero")),
2831 (Some(f32::NAN), Some("NaN")),
2832 (Some(3.2), Some("three point two")),
2833 (Some(f32::NAN), Some("should not use this NaN branch")),
2835 (Some(f32::INFINITY), Some("Infinity")),
2836 (Some(0.0), Some("zero")),
2837 (
2839 Some(f32::INFINITY),
2840 Some("should not use this Infinity branch"),
2841 ),
2842 (Some(1.1), Some("one point one")),
2843 ]);
2844
2845 let (input_values, expected) =
2846 create_input_and_expected::<Float32Array, StringArray, _, _>([
2847 (1.1, Some("one point one")),
2848 (f32::NAN, Some("NaN")),
2849 (3.2, Some("three point two")),
2850 (3.2, Some("three point two")),
2851 (0.0, Some("zero")),
2852 (f32::INFINITY, Some("Infinity")),
2853 (3.2, Some("three point two")),
2854 (f32::NEG_INFINITY, None), (f32::NEG_INFINITY, None), (3.2, Some("three point two")),
2857 (-0.0, None), ]);
2859
2860 test_lookup_eval_with_and_without_else(
2861 &lookup_map,
2862 Arc::new(input_values),
2863 expected,
2864 );
2865 }
2866
2867 #[test]
2868 fn test_case_when_literal_lookup_f16_to_string_with_special_values() {
2869 let lookup_map = create_lookup([
2870 (
2871 ScalarValue::Float16(Some(f16::from_f32(3.2))),
2872 Some("3 dot 2"),
2873 ),
2874 (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")),
2875 (
2876 ScalarValue::Float16(Some(f16::from_f32(17.4))),
2877 Some("17 dot 4"),
2878 ),
2879 (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")),
2880 (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")),
2881 ]);
2882
2883 let (input_values, expected) =
2884 create_input_and_expected::<Float16Array, StringArray, _, _>([
2885 (f16::from_f32(3.2), Some("3 dot 2")),
2886 (f16::NAN, Some("NaN")),
2887 (f16::from_f32(17.4), Some("17 dot 4")),
2888 (f16::from_f32(17.4), Some("17 dot 4")),
2889 (f16::INFINITY, Some("Infinity")),
2890 (f16::from_f32(17.4), Some("17 dot 4")),
2891 (f16::NEG_INFINITY, None), (f16::NEG_INFINITY, None), (f16::from_f32(17.4), Some("17 dot 4")),
2894 (f16::NEG_ZERO, None), ]);
2896
2897 test_lookup_eval_with_and_without_else(
2898 &lookup_map,
2899 Arc::new(input_values),
2900 expected,
2901 );
2902 }
2903
2904 #[test]
2905 fn test_case_when_literal_lookup_f32_to_string_with_special_values() {
2906 let lookup_map = create_lookup([
2907 (3.2, Some("3 dot 2")),
2908 (f32::NAN, Some("NaN")),
2909 (17.4, Some("17 dot 4")),
2910 (f32::INFINITY, Some("Infinity")),
2911 (f32::ZERO, Some("zero")),
2912 ]);
2913
2914 let (input_values, expected) =
2915 create_input_and_expected::<Float32Array, StringArray, _, _>([
2916 (3.2, Some("3 dot 2")),
2917 (f32::NAN, Some("NaN")),
2918 (17.4, Some("17 dot 4")),
2919 (17.4, Some("17 dot 4")),
2920 (f32::INFINITY, Some("Infinity")),
2921 (17.4, Some("17 dot 4")),
2922 (f32::NEG_INFINITY, None), (f32::NEG_INFINITY, None), (17.4, Some("17 dot 4")),
2925 (-0.0, None), ]);
2927
2928 test_lookup_eval_with_and_without_else(
2929 &lookup_map,
2930 Arc::new(input_values),
2931 expected,
2932 );
2933 }
2934
2935 #[test]
2936 fn test_case_when_literal_lookup_f64_to_string_with_special_values() {
2937 let lookup_map = create_lookup([
2938 (3.2, Some("3 dot 2")),
2939 (f64::NAN, Some("NaN")),
2940 (17.4, Some("17 dot 4")),
2941 (f64::INFINITY, Some("Infinity")),
2942 (f64::ZERO, Some("zero")),
2943 ]);
2944
2945 let (input_values, expected) =
2946 create_input_and_expected::<Float64Array, StringArray, _, _>([
2947 (3.2, Some("3 dot 2")),
2948 (f64::NAN, Some("NaN")),
2949 (17.4, Some("17 dot 4")),
2950 (17.4, Some("17 dot 4")),
2951 (f64::INFINITY, Some("Infinity")),
2952 (17.4, Some("17 dot 4")),
2953 (f64::NEG_INFINITY, None), (f64::NEG_INFINITY, None), (17.4, Some("17 dot 4")),
2956 (-0.0, None), ]);
2958
2959 test_lookup_eval_with_and_without_else(
2960 &lookup_map,
2961 Arc::new(input_values),
2962 expected,
2963 );
2964 }
2965
2966 #[test]
2968 fn test_decimal_with_non_default_precision_and_scale() {
2969 let lookup_map = create_lookup([
2970 (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")),
2971 (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")),
2972 (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")),
2973 (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")),
2974 ]);
2975
2976 let (input_values, expected) =
2977 create_input_and_expected::<Decimal32Array, StringArray, _, _>([
2978 (1, Some("one")),
2979 (2, Some("two")),
2980 (3, Some("three")),
2981 (3, Some("three")),
2982 (2, Some("two")),
2983 (3, Some("three")),
2984 (5, None), (5, None), (3, Some("three")),
2987 (5, None), ]);
2989
2990 let input_values = input_values
2991 .with_precision_and_scale(3, 2)
2992 .expect("must be able to set precision and scale");
2993
2994 test_lookup_eval_with_and_without_else(
2995 &lookup_map,
2996 Arc::new(input_values),
2997 expected,
2998 );
2999 }
3000
3001 #[test]
3003 fn test_timestamp_with_non_default_timezone() {
3004 let timezone: Option<Arc<str>> = Some("-10:00".into());
3005 let lookup_map = create_lookup([
3006 (
3007 ScalarValue::TimestampMillisecond(Some(4), timezone.clone()),
3008 Some("four"),
3009 ),
3010 (
3011 ScalarValue::TimestampMillisecond(Some(2), timezone.clone()),
3012 Some("two"),
3013 ),
3014 (
3015 ScalarValue::TimestampMillisecond(Some(3), timezone.clone()),
3016 Some("three"),
3017 ),
3018 (
3019 ScalarValue::TimestampMillisecond(Some(1), timezone.clone()),
3020 Some("one"),
3021 ),
3022 ]);
3023
3024 let (input_values, expected) =
3025 create_input_and_expected::<TimestampMillisecondArray, StringArray, _, _>([
3026 (1, Some("one")),
3027 (2, Some("two")),
3028 (3, Some("three")),
3029 (3, Some("three")),
3030 (2, Some("two")),
3031 (3, Some("three")),
3032 (5, None), (5, None), (3, Some("three")),
3035 (5, None), ]);
3037
3038 let input_values = input_values.with_timezone_opt(timezone);
3039
3040 test_lookup_eval_with_and_without_else(
3041 &lookup_map,
3042 Arc::new(input_values),
3043 expected,
3044 );
3045 }
3046
3047 #[test]
3048 fn test_with_strings_to_int32() {
3049 let lookup_map = create_lookup([
3050 (Some("why"), Some(42)),
3051 (Some("what"), Some(22)),
3052 (Some("when"), Some(17)),
3053 ]);
3054
3055 let (input_values, expected) =
3056 create_input_and_expected::<StringArray, Int32Array, _, _>([
3057 (Some("why"), Some(42)),
3058 (Some("5"), None), (None, None), (Some("what"), Some(22)),
3061 (None, None), (None, None), (Some("what"), Some(22)),
3064 (Some("5"), None), ]);
3066
3067 let input_values = Arc::new(input_values) as ArrayRef;
3068
3069 test_case_when_literal_lookup(
3071 Arc::clone(&input_values),
3072 &lookup_map,
3073 None,
3074 Arc::new(expected.clone()),
3075 );
3076
3077 let else_value = 101;
3079
3080 let expected_with_else = expected
3082 .iter()
3083 .map(|item| item.unwrap_or(else_value))
3084 .map(Some)
3085 .collect::<Int32Array>();
3086
3087 test_case_when_literal_lookup(
3089 input_values,
3090 &lookup_map,
3091 Some(ScalarValue::Int32(Some(else_value))),
3092 Arc::new(expected_with_else),
3093 );
3094 }
3095}