1mod literal_lookup_table;
19
20use super::{Column, Literal};
21use crate::PhysicalExpr;
22use crate::expressions::{LambdaVariable, lit, try_cast};
23use arrow::array::*;
24use arrow::compute::kernels::zip::zip;
25use arrow::compute::{
26 FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::{
32 DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err,
33 internal_datafusion_err, internal_err,
34};
35use datafusion_expr::ColumnarValue;
36use indexmap::IndexMap;
37use std::borrow::Cow;
38use std::collections::BTreeSet;
39use std::hash::Hash;
40use std::sync::Arc;
41
42use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
43use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
44use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
45use datafusion_physical_expr_common::datum::compare_with_eq;
46use datafusion_physical_expr_common::utils::scatter;
47use itertools::Itertools;
48use std::fmt::{Debug, Formatter};
49
/// A `(WHEN, THEN)` expression pair of a CASE expression.
pub(super) type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
51
/// Strategy used by [`CaseExpr`] to evaluate a batch, chosen once by
/// `CaseExpr::find_best_eval_method` when the expression is constructed.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// `CASE WHEN ... THEN ... [ELSE ...] END` without a base expression and
    /// with more than one WHEN clause; evaluated by `case_when_no_expr`.
    NoExpression(ProjectedCaseBody),
    /// `CASE <expr> WHEN ... THEN ... [ELSE ...] END` for which no literal
    /// lookup table could be built; evaluated by `case_when_with_expr`.
    WithExpression(ProjectedCaseBody),
    /// A single WHEN clause whose THEN expression is a column (see
    /// `is_cheap_and_infallible`) and no ELSE; unselected rows become NULL.
    InfallibleExprOrNull,
    /// A single WHEN clause whose THEN and ELSE expressions are both
    /// literals; evaluated by zipping the two scalar values.
    ScalarOrScalar,
    /// Any other single-WHEN CASE without a base expression; evaluated by
    /// `expr_or_expr`.
    ExpressionOrExpression(ProjectedCaseBody),

    /// `CASE <expr> ...` for which `LiteralLookupTable::maybe_new` returned a
    /// table; evaluated by mapping the base values through that table.
    WithExprScalarLookupTable(LiteralLookupTable),
}
91
92impl Hash for LiteralLookupTable {
99 fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
100}
101
102impl PartialEq for LiteralLookupTable {
109 fn eq(&self, _other: &Self) -> bool {
110 true
111 }
112}
113
114impl Eq for LiteralLookupTable {}
115
/// The components of a CASE expression.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression (`CASE <expr> WHEN ...`). When present, each
    /// WHEN value is compared with it for equality.
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// The `(WHEN, THEN)` pairs, in evaluation order.
    when_then_expr: Vec<WhenThen>,
    /// Optional ELSE expression.
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
127
128impl CaseBody {
129 fn project(&self) -> Result<ProjectedCaseBody> {
131 let mut used_column_indices = BTreeSet::<usize>::new();
134 let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
135 expr.apply(|expr| {
136 if let Some(column) = expr.downcast_ref::<Column>() {
137 used_column_indices.insert(column.index());
138 } else if let Some(lambda_variable) =
139 expr.downcast_ref::<LambdaVariable>()
140 {
141 used_column_indices.insert(lambda_variable.index());
142 }
143 Ok(TreeNodeRecursion::Continue)
144 })
145 .expect("Closure cannot fail");
146 };
147
148 if let Some(e) = &self.expr {
149 collect_column_indices(e);
150 }
151 self.when_then_expr.iter().for_each(|(w, t)| {
152 collect_column_indices(w);
153 collect_column_indices(t);
154 });
155 if let Some(e) = &self.else_expr {
156 collect_column_indices(e);
157 }
158
159 let column_index_map = used_column_indices
161 .iter()
162 .enumerate()
163 .map(|(projected, original)| (*original, projected))
164 .collect::<IndexMap<usize, usize>>();
165
166 let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
169 Arc::clone(expr)
170 .transform_down(|e| {
171 if let Some(column) = e.downcast_ref::<Column>() {
172 let original = column.index();
173 let projected = *column_index_map.get(&original).unwrap();
174 if projected != original {
175 return Ok(Transformed::yes(Arc::new(Column::new(
176 column.name(),
177 projected,
178 ))));
179 }
180 } else if let Some(lambda_variable) =
181 e.downcast_ref::<LambdaVariable>()
182 {
183 let original = lambda_variable.index();
184 let projected = *column_index_map.get(&original).unwrap();
185 if projected != original {
186 return Ok(Transformed::yes(Arc::new(LambdaVariable::new(
187 projected,
188 Arc::clone(lambda_variable.field()),
189 ))));
190 }
191 }
192 Ok(Transformed::no(e))
193 })
194 .map(|t| t.data)
195 };
196
197 let projected_body = CaseBody {
198 expr: self.expr.as_ref().map(project).transpose()?,
199 when_then_expr: self
200 .when_then_expr
201 .iter()
202 .map(|(e, t)| Ok((project(e)?, project(t)?)))
203 .collect::<Result<Vec<_>>>()?,
204 else_expr: self.else_expr.as_ref().map(project).transpose()?,
205 };
206
207 let projection = column_index_map
209 .iter()
210 .sorted_by_key(|(_, v)| **v)
211 .map(|(k, _)| *k)
212 .collect::<Vec<_>>();
213
214 Ok(ProjectedCaseBody {
215 projection,
216 body: projected_body,
217 })
218 }
219}
220
/// A [`CaseBody`] rewritten (by `CaseBody::project`) to reference only the
/// columns it actually uses.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// `projection[i]` is the original column index found at position `i` of
    /// the projected batch; sorted ascending.
    projection: Vec<usize>,
    /// The body with its column references renumbered to projected positions.
    body: CaseBody,
}
253
/// Physical CASE expression: `CASE [expr] WHEN ... THEN ... [ELSE ...] END`.
#[derive(Debug)]
pub struct CaseExpr {
    /// The components of the CASE expression.
    body: CaseBody,
    /// Evaluation strategy selected at construction by
    /// `find_best_eval_method`; derived from `body`.
    eval_method: EvalMethod,
}
278
279impl Hash for CaseExpr {
283 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
284 self.body.hash(state);
285 }
286}
287
288impl PartialEq for CaseExpr {
289 fn eq(&self, other: &Self) -> bool {
290 self.body == other.body
291 }
292}
293
294impl Eq for CaseExpr {}
295
296impl std::fmt::Display for CaseExpr {
297 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
298 write!(f, "CASE ")?;
299 if let Some(e) = &self.body.expr {
300 write!(f, "{e} ")?;
301 }
302 for (w, t) in &self.body.when_then_expr {
303 write!(f, "WHEN {w} THEN {t} ")?;
304 }
305 if let Some(e) = &self.body.else_expr {
306 write!(f, "ELSE {e} ")?;
307 }
308 write!(f, "END")
309 }
310}
311
312fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
318 expr.is::<Column>()
319}
320
321fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
323 let mut filter_builder = FilterBuilder::new(predicate);
324 if optimize {
325 filter_builder = filter_builder.optimize();
327 }
328 filter_builder.build()
329}
330
331fn multiple_arrays(data_type: &DataType) -> bool {
332 match data_type {
333 DataType::Struct(fields) => {
334 fields.len() > 1
335 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
336 }
337 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
338 _ => false,
339 }
340}
341
342fn filter_record_batch(
345 record_batch: &RecordBatch,
346 filter: &FilterPredicate,
347) -> std::result::Result<RecordBatch, ArrowError> {
348 let filtered_columns = record_batch
349 .columns()
350 .iter()
351 .map(|a| filter_array(a, filter))
352 .collect::<std::result::Result<Vec<_>, _>>()?;
353 unsafe {
359 Ok(RecordBatch::new_unchecked(
360 record_batch.schema(),
361 filtered_columns,
362 filter.count(),
363 ))
364 }
365}
366
367#[inline(always)]
372fn filter_array(
373 array: &dyn Array,
374 filter: &FilterPredicate,
375) -> std::result::Result<ArrayRef, ArrowError> {
376 filter.filter(array)
377}
378
/// Per-row index into the `arrays` of a partial result (`ResultState::Partial`).
///
/// Stored as a `u32`; the reserved value [`NONE_VALUE`] marks a row that has
/// no value yet.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    index: u32,
}
387
/// Reserved [`PartialResultIndex`] value meaning "no value for this row".
const NONE_VALUE: u32 = u32::MAX;
389
390impl PartialResultIndex {
391 fn none() -> Self {
393 Self { index: NONE_VALUE }
394 }
395
396 fn zero() -> Self {
397 Self { index: 0 }
398 }
399
400 fn try_new(index: usize) -> Result<Self> {
405 let Ok(index) = u32::try_from(index) else {
406 return internal_err!("Partial result index exceeds limit");
407 };
408
409 assert_or_internal_err!(
410 index != NONE_VALUE,
411 "Partial result index exceeds limit"
412 );
413
414 Ok(Self { index })
415 }
416
417 fn is_none(&self) -> bool {
419 self.index == NONE_VALUE
420 }
421}
422
423impl MergeIndex for PartialResultIndex {
424 fn index(&self) -> Option<usize> {
426 if self.is_none() {
427 None
428 } else {
429 Some(self.index as usize)
430 }
431 }
432}
433
434impl Debug for PartialResultIndex {
435 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
436 if self.is_none() {
437 write!(f, "null")
438 } else {
439 write!(f, "{}", self.index)
440 }
441 }
442}
443
/// What a [`ResultBuilder`] has collected so far.
enum ResultState {
    /// No branch result has been added yet.
    Empty,
    /// Values for a subset of rows, gathered from one or more branches.
    Partial {
        /// One values array per added branch result, in insertion order.
        arrays: Vec<ArrayRef>,
        /// For every output row, which entry of `arrays` provides its value,
        /// or the "none" index when no branch produced a value for that row.
        indices: Vec<PartialResultIndex>,
    },
    /// A single value that covers every row.
    Complete(ColumnarValue),
}
459
/// Assembles the CASE result for a batch from per-branch results.
struct ResultBuilder {
    /// Result type; used to build a typed NULL when no branch contributed.
    data_type: DataType,
    /// Number of rows in the batch being evaluated.
    row_count: usize,
    /// Results collected so far.
    state: ResultState,
}
475
476impl ResultBuilder {
477 fn new(data_type: &DataType, row_count: usize) -> Self {
481 Self {
482 data_type: data_type.clone(),
483 row_count,
484 state: ResultState::Empty,
485 }
486 }
487
488 fn add_branch_result(
521 &mut self,
522 row_indices: &ArrayRef,
523 value: ColumnarValue,
524 ) -> Result<()> {
525 match value {
526 ColumnarValue::Array(a) => {
527 if a.len() != row_indices.len() {
528 internal_err!("Array length must match row indices length")
529 } else if row_indices.len() == self.row_count {
530 self.set_complete_result(ColumnarValue::Array(a))
531 } else {
532 self.add_partial_result(row_indices, a)
533 }
534 }
535 ColumnarValue::Scalar(s) => {
536 if row_indices.len() == self.row_count {
537 self.set_complete_result(ColumnarValue::Scalar(s))
538 } else {
539 self.add_partial_result(
540 row_indices,
541 s.to_array_of_size(row_indices.len())?,
542 )
543 }
544 }
545 }
546 }
547
548 fn add_partial_result(
554 &mut self,
555 row_indices: &ArrayRef,
556 row_values: ArrayRef,
557 ) -> Result<()> {
558 assert_or_internal_err!(
559 row_indices.null_count() == 0,
560 "Row indices must not contain nulls"
561 );
562
563 match &mut self.state {
564 ResultState::Empty => {
565 let array_index = PartialResultIndex::zero();
566 let mut indices = vec![PartialResultIndex::none(); self.row_count];
567 for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
568 indices[*row_ix as usize] = array_index;
569 }
570
571 self.state = ResultState::Partial {
572 arrays: vec![row_values],
573 indices,
574 };
575
576 Ok(())
577 }
578 ResultState::Partial { arrays, indices } => {
579 let array_index = PartialResultIndex::try_new(arrays.len())?;
580
581 arrays.push(row_values);
582
583 for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
584 #[cfg(debug_assertions)]
588 assert_or_internal_err!(
589 indices[*row_ix as usize].is_none(),
590 "Duplicate value for row {}",
591 *row_ix
592 );
593
594 indices[*row_ix as usize] = array_index;
595 }
596 Ok(())
597 }
598 ResultState::Complete(_) => internal_err!(
599 "Cannot add a partial result when complete result is already set"
600 ),
601 }
602 }
603
604 fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
610 match &self.state {
611 ResultState::Empty => {
612 self.state = ResultState::Complete(value);
613 Ok(())
614 }
615 ResultState::Partial { .. } => {
616 internal_err!(
617 "Cannot set a complete result when there are already partial results"
618 )
619 }
620 ResultState::Complete(_) => internal_err!("Complete result already set"),
621 }
622 }
623
624 fn finish(self) -> Result<ColumnarValue> {
626 match self.state {
627 ResultState::Empty => {
628 Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
632 &self.data_type,
633 )?))
634 }
635 ResultState::Partial { arrays, indices } => {
636 let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
638 Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
639 }
640 ResultState::Complete(v) => {
641 Ok(v)
643 }
644 }
645 }
646}
647
648impl CaseExpr {
649 pub fn try_new(
651 expr: Option<Arc<dyn PhysicalExpr>>,
652 when_then_expr: Vec<WhenThen>,
653 else_expr: Option<Arc<dyn PhysicalExpr>>,
654 ) -> Result<Self> {
655 let else_expr = match &else_expr {
658 Some(e) => match e.downcast_ref::<Literal>() {
659 Some(lit) if lit.value().is_null() => None,
660 _ => else_expr,
661 },
662 _ => else_expr,
663 };
664
665 if when_then_expr.is_empty() {
666 return exec_err!("There must be at least one WHEN clause");
667 }
668
669 let body = CaseBody {
670 expr,
671 when_then_expr,
672 else_expr,
673 };
674
675 let eval_method = Self::find_best_eval_method(&body)?;
676
677 Ok(Self { body, eval_method })
678 }
679
680 fn find_best_eval_method(body: &CaseBody) -> Result<EvalMethod> {
681 if body.expr.is_some() {
682 if let Some(mapping) = LiteralLookupTable::maybe_new(body) {
683 return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
684 }
685
686 return Ok(EvalMethod::WithExpression(body.project()?));
687 }
688
689 Ok(
690 if body.when_then_expr.len() == 1
691 && is_cheap_and_infallible(&(body.when_then_expr[0].1))
692 && body.else_expr.is_none()
693 {
694 EvalMethod::InfallibleExprOrNull
695 } else if body.when_then_expr.len() == 1
696 && body.when_then_expr[0].1.is::<Literal>()
697 && body.else_expr.is_some()
698 && body.else_expr.as_ref().unwrap().is::<Literal>()
699 {
700 EvalMethod::ScalarOrScalar
701 } else if body.when_then_expr.len() == 1 {
702 EvalMethod::ExpressionOrExpression(body.project()?)
703 } else {
704 EvalMethod::NoExpression(body.project()?)
705 },
706 )
707 }
708
709 pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
711 self.body.expr.as_ref()
712 }
713
714 pub fn when_then_expr(&self) -> &[WhenThen] {
716 &self.body.when_then_expr
717 }
718
719 pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
721 self.body.else_expr.as_ref()
722 }
723}
724
725impl CaseBody {
726 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
727 let mut data_type = DataType::Null;
730 for i in 0..self.when_then_expr.len() {
731 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
732 if !data_type.equals_datatype(&DataType::Null) {
733 break;
734 }
735 }
736 if data_type.equals_datatype(&DataType::Null)
738 && let Some(e) = &self.else_expr
739 {
740 data_type = e.data_type(input_schema)?;
741 }
742
743 Ok(data_type)
744 }
745
746 fn case_when_with_expr(
748 &self,
749 batch: &RecordBatch,
750 return_type: &DataType,
751 ) -> Result<ColumnarValue> {
752 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
753
754 let mut remainder_rows: ArrayRef =
756 Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
757 let mut remainder_batch = Cow::Borrowed(batch);
759
760 let mut base_values = self
762 .expr
763 .as_ref()
764 .unwrap()
765 .evaluate(batch)?
766 .into_array(batch.num_rows())?;
767
768 let base_null_count = base_values.logical_null_count();
773 if base_null_count > 0 {
774 let base_not_nulls = is_not_null(base_values.as_ref())?;
778 let base_all_null = base_null_count == remainder_batch.num_rows();
779
780 if let Some(e) = &self.else_expr {
783 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
784
785 if base_all_null {
786 let nulls_value = expr.evaluate(&remainder_batch)?;
788 result_builder.add_branch_result(&remainder_rows, nulls_value)?;
789 } else {
790 let nulls_filter = create_filter(¬(&base_not_nulls)?, true);
792 let nulls_batch =
793 filter_record_batch(&remainder_batch, &nulls_filter)?;
794 let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
795 let nulls_value = expr.evaluate(&nulls_batch)?;
796 result_builder.add_branch_result(&nulls_rows, nulls_value)?;
797 }
798 }
799
800 if base_all_null {
802 return result_builder.finish();
803 }
804
805 let not_null_filter = create_filter(&base_not_nulls, true);
807 remainder_batch =
808 Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?);
809 remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?;
810 base_values = filter_array(&base_values, ¬_null_filter)?;
811 }
812
813 let base_value_is_nested = base_values.data_type().is_nested();
816
817 for i in 0..self.when_then_expr.len() {
818 let when_expr = &self.when_then_expr[i].0;
821 let when_value = match when_expr.evaluate(&remainder_batch)? {
822 ColumnarValue::Array(a) => {
823 compare_with_eq(&a, &base_values, base_value_is_nested)
824 }
825 ColumnarValue::Scalar(s) => {
826 compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
827 }
828 }?;
829
830 if !when_value.has_true() {
834 continue;
835 }
836
837 if when_value.null_count() == 0 && !when_value.has_false() {
839 let then_expression = &self.when_then_expr[i].1;
840 let then_value = then_expression.evaluate(&remainder_batch)?;
841 result_builder.add_branch_result(&remainder_rows, then_value)?;
842 return result_builder.finish();
843 }
844
845 let then_filter = create_filter(&when_value, true);
851 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
852 let then_rows = filter_array(&remainder_rows, &then_filter)?;
853
854 let then_expression = &self.when_then_expr[i].1;
855 let then_value = then_expression.evaluate(&then_batch)?;
856 result_builder.add_branch_result(&then_rows, then_value)?;
857
858 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
861 return result_builder.finish();
862 }
863
864 let next_selection = match when_value.null_count() {
866 0 => not(&when_value),
867 _ => {
868 not(&prep_null_mask_filter(&when_value))
871 }
872 }?;
873 let next_filter = create_filter(&next_selection, true);
874 remainder_batch =
875 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
876 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
877 base_values = filter_array(&base_values, &next_filter)?;
878 }
879
880 if let Some(e) = &self.else_expr {
883 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
885 let else_value = expr.evaluate(&remainder_batch)?;
886 result_builder.add_branch_result(&remainder_rows, else_value)?;
887 }
888
889 result_builder.finish()
890 }
891
892 fn case_when_no_expr(
894 &self,
895 batch: &RecordBatch,
896 return_type: &DataType,
897 ) -> Result<ColumnarValue> {
898 let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
899
900 let mut remainder_rows: ArrayRef =
902 Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
903 let mut remainder_batch = Cow::Borrowed(batch);
905
906 for i in 0..self.when_then_expr.len() {
907 let when_predicate = &self.when_then_expr[i].0;
910 let when_value = when_predicate
911 .evaluate(&remainder_batch)?
912 .into_array(remainder_batch.num_rows())?;
913 let when_value = as_boolean_array(&when_value).map_err(|_| {
914 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
915 })?;
916
917 if !when_value.has_true() {
921 continue;
922 }
923
924 if when_value.null_count() == 0 && !when_value.has_false() {
926 let then_expression = &self.when_then_expr[i].1;
927 let then_value = then_expression.evaluate(&remainder_batch)?;
928 result_builder.add_branch_result(&remainder_rows, then_value)?;
929 return result_builder.finish();
930 }
931
932 let then_filter = create_filter(when_value, true);
938 let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
939 let then_rows = filter_array(&remainder_rows, &then_filter)?;
940
941 let then_expression = &self.when_then_expr[i].1;
942 let then_value = then_expression.evaluate(&then_batch)?;
943 result_builder.add_branch_result(&then_rows, then_value)?;
944
945 if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
948 return result_builder.finish();
949 }
950
951 let next_selection = match when_value.null_count() {
953 0 => not(when_value),
954 _ => {
955 not(&prep_null_mask_filter(when_value))
958 }
959 }?;
960 let next_filter = create_filter(&next_selection, true);
961 remainder_batch =
962 Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
963 remainder_rows = filter_array(&remainder_rows, &next_filter)?;
964 }
965
966 if let Some(e) = &self.else_expr {
969 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
971 let else_value = expr.evaluate(&remainder_batch)?;
972 result_builder.add_branch_result(&remainder_rows, else_value)?;
973 }
974
975 result_builder.finish()
976 }
977
978 fn expr_or_expr(
980 &self,
981 batch: &RecordBatch,
982 when_value: &BooleanArray,
983 ) -> Result<ColumnarValue> {
984 let when_value = match when_value.null_count() {
985 0 => Cow::Borrowed(when_value),
986 _ => {
987 Cow::Owned(prep_null_mask_filter(when_value))
989 }
990 };
991
992 let optimize_filter = batch.num_columns() > 1
993 || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
994
995 let when_filter = create_filter(&when_value, optimize_filter);
996 let then_batch = filter_record_batch(batch, &when_filter)?;
997 let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
998
999 match &self.else_expr {
1000 None => {
1001 let then_array = then_value.to_array(when_value.true_count())?;
1002 scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array)
1003 }
1004 Some(else_expr) => {
1005 let else_selection = not(&when_value)?;
1006 let else_filter = create_filter(&else_selection, optimize_filter);
1007 let else_batch = filter_record_batch(batch, &else_filter)?;
1008
1009 let return_type = self.data_type(&batch.schema())?;
1011 let else_expr =
1012 try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone())
1013 .unwrap_or_else(|_| Arc::clone(else_expr));
1014
1015 let else_value = else_expr.evaluate(&else_batch)?;
1016
1017 Ok(ColumnarValue::Array(match (then_value, else_value) {
1018 (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
1019 merge(&when_value, &t, &e)
1020 }
1021 (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
1022 merge(&when_value, &t.to_scalar()?, &e)
1023 }
1024 (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
1025 merge(&when_value, &t, &e.to_scalar()?)
1026 }
1027 (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
1028 merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
1029 }
1030 }?))
1031 }
1032 }
1033 }
1034}
1035
1036impl CaseExpr {
1037 fn case_when_with_expr(
1045 &self,
1046 batch: &RecordBatch,
1047 projected: &ProjectedCaseBody,
1048 ) -> Result<ColumnarValue> {
1049 let return_type = self.data_type(&batch.schema())?;
1050 let projection = projected
1052 .projection
1053 .iter()
1054 .copied()
1055 .filter(|index| *index < batch.num_columns())
1056 .collect::<Vec<_>>();
1057 if projection.len() < batch.num_columns() {
1058 let projected_batch = batch.project(&projection)?;
1059 projected
1060 .body
1061 .case_when_with_expr(&projected_batch, &return_type)
1062 } else {
1063 self.body.case_when_with_expr(batch, &return_type)
1064 }
1065 }
1066
1067 fn case_when_no_expr(
1075 &self,
1076 batch: &RecordBatch,
1077 projected: &ProjectedCaseBody,
1078 ) -> Result<ColumnarValue> {
1079 let return_type = self.data_type(&batch.schema())?;
1080 let projection = projected
1082 .projection
1083 .iter()
1084 .copied()
1085 .filter(|index| *index < batch.num_columns())
1086 .collect::<Vec<_>>();
1087 if projection.len() < batch.num_columns() {
1088 let projected_batch = batch.project(&projection)?;
1089 projected
1090 .body
1091 .case_when_no_expr(&projected_batch, &return_type)
1092 } else {
1093 self.body.case_when_no_expr(batch, &return_type)
1094 }
1095 }
1096
1097 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1107 let when_expr = &self.body.when_then_expr[0].0;
1108 let then_expr = &self.body.when_then_expr[0].1;
1109
1110 match when_expr.evaluate(batch)? {
1111 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1113 then_expr.evaluate(batch)
1114 }
1115 ColumnarValue::Scalar(_) => {
1117 ScalarValue::try_from(self.data_type(&batch.schema())?)
1119 .map(ColumnarValue::Scalar)
1120 }
1121 ColumnarValue::Array(bit_mask) => {
1123 let bit_mask = bit_mask
1124 .as_any()
1125 .downcast_ref::<BooleanArray>()
1126 .expect("predicate should evaluate to a boolean array");
1127 let bit_mask = match bit_mask.null_count() {
1129 0 => not(bit_mask)?,
1130 _ => not(&prep_null_mask_filter(bit_mask))?,
1131 };
1132 match then_expr.evaluate(batch)? {
1133 ColumnarValue::Array(array) => {
1134 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1135 }
1136 ColumnarValue::Scalar(_) => {
1137 internal_err!("expression did not evaluate to an array")
1138 }
1139 }
1140 }
1141 }
1142 }
1143
1144 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1145 let return_type = self.data_type(&batch.schema())?;
1146
1147 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1149 let when_value = when_value.into_array(batch.num_rows())?;
1150 let when_value = as_boolean_array(&when_value).map_err(|_| {
1151 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1152 })?;
1153
1154 let when_value = match when_value.null_count() {
1156 0 => Cow::Borrowed(when_value),
1157 _ => Cow::Owned(prep_null_mask_filter(when_value)),
1158 };
1159
1160 let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1162 let then_value = Scalar::new(then_value.into_array(1)?);
1163
1164 let Some(e) = &self.body.else_expr else {
1165 return internal_err!("expression did not evaluate to an array");
1166 };
1167 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1169 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1170 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1171 }
1172
1173 fn expr_or_expr(
1174 &self,
1175 batch: &RecordBatch,
1176 projected: &ProjectedCaseBody,
1177 ) -> Result<ColumnarValue> {
1178 let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1180 let when_value = when_value.into_array(1)?;
1184 let when_value = as_boolean_array(&when_value).map_err(|e| {
1185 DataFusionError::Context(
1186 "WHEN expression did not return a BooleanArray".to_string(),
1187 Box::new(e),
1188 )
1189 })?;
1190
1191 if when_value.null_count() == 0 && !when_value.has_false() {
1192 self.body.when_then_expr[0].1.evaluate(batch)
1194 } else if !when_value.has_true() {
1195 match &self.body.else_expr {
1197 Some(else_expr) => else_expr.evaluate(batch),
1198 None => {
1199 let return_type = self.data_type(&batch.schema())?;
1200 Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
1201 &return_type,
1202 )?))
1203 }
1204 }
1205 } else {
1206 let projection = projected
1208 .projection
1209 .iter()
1210 .copied()
1211 .filter(|index| *index < batch.num_columns())
1212 .collect::<Vec<_>>();
1213 if projection.len() < batch.num_columns() {
1214 let projected_batch = batch.project(&projection)?;
1217 projected.body.expr_or_expr(&projected_batch, when_value)
1218 } else {
1219 self.body.expr_or_expr(batch, when_value)
1221 }
1222 }
1223 }
1224
1225 fn with_lookup_table(
1226 &self,
1227 batch: &RecordBatch,
1228 lookup_table: &LiteralLookupTable,
1229 ) -> Result<ColumnarValue> {
1230 let expr = self.body.expr.as_ref().unwrap();
1231 let evaluated_expression = expr.evaluate(batch)?;
1232
1233 let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_));
1234 let evaluated_expression = evaluated_expression.to_array(1)?;
1235
1236 let values = lookup_table.map_keys_to_values(&evaluated_expression)?;
1237
1238 let result = if is_scalar {
1239 ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?)
1240 } else {
1241 ColumnarValue::Array(values)
1242 };
1243
1244 Ok(result)
1245 }
1246}
1247
1248impl PhysicalExpr for CaseExpr {
1249 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1250 self.body.data_type(input_schema)
1251 }
1252
1253 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1254 let nullable_then = self
1255 .body
1256 .when_then_expr
1257 .iter()
1258 .filter_map(|(w, t)| {
1259 let is_nullable = match t.nullable(input_schema) {
1260 Err(e) => return Some(Err(e)),
1262 Ok(n) => n,
1263 };
1264
1265 if !is_nullable {
1268 return None;
1269 }
1270
1271 if self.body.expr.is_some() {
1273 return Some(Ok(()));
1274 }
1275
1276 let with_null = match replace_with_null(w, t.as_ref(), input_schema) {
1282 Err(e) => return Some(Err(e)),
1283 Ok(e) => e,
1284 };
1285
1286 let predicate_result = match evaluate_predicate(&with_null) {
1288 Err(e) => return Some(Err(e)),
1289 Ok(b) => b,
1290 };
1291
1292 match predicate_result {
1293 None | Some(true) => Some(Ok(())),
1295 Some(false) => None,
1298 }
1299 })
1300 .next();
1301
1302 if let Some(nullable_then) = nullable_then {
1303 nullable_then.map(|_| true)
1307 } else if let Some(e) = &self.body.else_expr {
1308 e.nullable(input_schema)
1311 } else {
1312 Ok(true)
1315 }
1316 }
1317
1318 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1319 match &self.eval_method {
1320 EvalMethod::WithExpression(p) => {
1321 self.case_when_with_expr(batch, p)
1324 }
1325 EvalMethod::NoExpression(p) => {
1326 self.case_when_no_expr(batch, p)
1329 }
1330 EvalMethod::InfallibleExprOrNull => {
1331 self.case_column_or_null(batch)
1333 }
1334 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1335 EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1336 EvalMethod::WithExprScalarLookupTable(lookup_table) => {
1337 self.with_lookup_table(batch, lookup_table)
1338 }
1339 }
1340 }
1341
1342 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1343 let mut children = vec![];
1344 if let Some(expr) = &self.body.expr {
1345 children.push(expr)
1346 }
1347 self.body.when_then_expr.iter().for_each(|(cond, value)| {
1348 children.push(cond);
1349 children.push(value);
1350 });
1351
1352 if let Some(else_expr) = &self.body.else_expr {
1353 children.push(else_expr)
1354 }
1355 children
1356 }
1357
1358 fn with_new_children(
1360 self: Arc<Self>,
1361 children: Vec<Arc<dyn PhysicalExpr>>,
1362 ) -> Result<Arc<dyn PhysicalExpr>> {
1363 if children.len() != self.children().len() {
1364 internal_err!("CaseExpr: Wrong number of children")
1365 } else {
1366 let (expr, when_then_expr, else_expr) =
1367 match (self.expr().is_some(), self.body.else_expr.is_some()) {
1368 (true, true) => (
1369 Some(&children[0]),
1370 &children[1..children.len() - 1],
1371 Some(&children[children.len() - 1]),
1372 ),
1373 (true, false) => {
1374 (Some(&children[0]), &children[1..children.len()], None)
1375 }
1376 (false, true) => (
1377 None,
1378 &children[0..children.len() - 1],
1379 Some(&children[children.len() - 1]),
1380 ),
1381 (false, false) => (None, &children[0..children.len()], None),
1382 };
1383 Ok(Arc::new(CaseExpr::try_new(
1384 expr.cloned(),
1385 when_then_expr.iter().cloned().tuples().collect(),
1386 else_expr.cloned(),
1387 )?))
1388 }
1389 }
1390
1391 fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1392 write!(f, "CASE ")?;
1393 if let Some(e) = &self.body.expr {
1394 e.fmt_sql(f)?;
1395 write!(f, " ")?;
1396 }
1397
1398 for (w, t) in &self.body.when_then_expr {
1399 write!(f, "WHEN ")?;
1400 w.fmt_sql(f)?;
1401 write!(f, " THEN ")?;
1402 t.fmt_sql(f)?;
1403 write!(f, " ")?;
1404 }
1405
1406 if let Some(e) = &self.body.else_expr {
1407 write!(f, "ELSE ")?;
1408 e.fmt_sql(f)?;
1409 write!(f, " ")?;
1410 }
1411 write!(f, "END")
1412 }
1413}
1414
1415fn evaluate_predicate(predicate: &Arc<dyn PhysicalExpr>) -> Result<Option<bool>> {
1421 let batch = RecordBatch::try_new_with_options(
1423 Arc::new(Schema::empty()),
1424 vec![],
1425 &RecordBatchOptions::new().with_row_count(Some(1)),
1426 )?;
1427
1428 let result = match predicate.evaluate(&batch) {
1430 Err(_) => None,
1432 Ok(ColumnarValue::Array(array)) => Some(
1433 ScalarValue::try_from_array(array.as_ref(), 0)?
1434 .cast_to(&DataType::Boolean)?,
1435 ),
1436 Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?),
1437 };
1438 Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true)))))
1439}
1440
1441fn replace_with_null(
1442 expr: &Arc<dyn PhysicalExpr>,
1443 expr_to_replace: &dyn PhysicalExpr,
1444 input_schema: &Schema,
1445) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
1446 let with_null = Arc::clone(expr)
1447 .transform_down(|e| {
1448 if e.as_ref().dyn_eq(expr_to_replace) {
1449 let data_type = e.data_type(input_schema)?;
1450 let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
1451 Ok(Transformed::yes(null_literal))
1452 } else {
1453 Ok(Transformed::no(e))
1454 }
1455 })?
1456 .data;
1457 Ok(with_null)
1458}
1459
1460pub fn case(
1462 expr: Option<Arc<dyn PhysicalExpr>>,
1463 when_thens: Vec<WhenThen>,
1464 else_expr: Option<Arc<dyn PhysicalExpr>>,
1465) -> Result<Arc<dyn PhysicalExpr>> {
1466 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1467}
1468
1469#[cfg(test)]
1470mod tests {
1471 use super::*;
1472
1473 use crate::expressions;
1474 use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null};
1475 use arrow::buffer::Buffer;
1476 use arrow::datatypes::DataType::Float64;
1477 use arrow::datatypes::Field;
1478 use datafusion_common::cast::{as_float64_array, as_int32_array};
1479 use datafusion_common::plan_err;
1480 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1481 use datafusion_expr::type_coercion::binary::type_union_coercion;
1482 use datafusion_expr_common::operator::Operator;
1483 use datafusion_physical_expr_common::physical_expr::fmt_sql;
1484 use half::f16;
1485
1486 #[test]
1487 fn case_with_expr() -> Result<()> {
1488 let batch = case_test_batch()?;
1489 let schema = batch.schema();
1490
1491 let when1 = lit("foo");
1493 let then1 = lit(123i32);
1494 let when2 = lit("bar");
1495 let then2 = lit(456i32);
1496
1497 let expr = generate_case_when_with_type_coercion(
1498 Some(col("a", &schema)?),
1499 vec![(when1, then1), (when2, then2)],
1500 None,
1501 schema.as_ref(),
1502 )?;
1503 let result = expr
1504 .evaluate(&batch)?
1505 .into_array(batch.num_rows())
1506 .expect("Failed to convert to array");
1507 let result = as_int32_array(&result)?;
1508
1509 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1510
1511 assert_eq!(expected, result);
1512
1513 Ok(())
1514 }
1515
1516 #[test]
1517 fn case_with_expr_dictionary() -> Result<()> {
1518 let schema = Schema::new(vec![Field::new(
1519 "a",
1520 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1521 true,
1522 )]);
1523 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1524 let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1525 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1526 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1527
1528 let schema = batch.schema();
1529
1530 let when1 = lit("foo");
1532 let then1 = lit(123i32);
1533 let when2 = lit("bar");
1534 let then2 = lit(456i32);
1535
1536 let expr = generate_case_when_with_type_coercion(
1537 Some(col("a", &schema)?),
1538 vec![(when1, then1), (when2, then2)],
1539 None,
1540 schema.as_ref(),
1541 )?;
1542 let result = expr
1543 .evaluate(&batch)?
1544 .into_array(batch.num_rows())
1545 .expect("Failed to convert to array");
1546 let result = as_int32_array(&result)?;
1547
1548 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1549
1550 assert_eq!(expected, result);
1551
1552 Ok(())
1553 }
1554
1555 #[test]
1557 fn case_with_expr_primitive_dictionary() -> Result<()> {
1558 let schema = Schema::new(vec![Field::new(
1559 "a",
1560 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64)),
1561 true,
1562 )]);
1563 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1564 let values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]);
1565 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1566 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1567
1568 let schema = batch.schema();
1569
1570 let when1 = lit(10_u64);
1572 let then1 = lit(123_i32);
1573 let when2 = lit(30_u64);
1574 let then2 = lit(456_i32);
1575
1576 let expr = generate_case_when_with_type_coercion(
1577 Some(col("a", &schema)?),
1578 vec![(when1, then1), (when2, then2)],
1579 None,
1580 schema.as_ref(),
1581 )?;
1582 let result = expr
1583 .evaluate(&batch)?
1584 .into_array(batch.num_rows())
1585 .expect("Failed to convert to array");
1586 let result = as_int32_array(&result)?;
1587
1588 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1589
1590 assert_eq!(expected, result);
1591
1592 Ok(())
1593 }
1594
1595 #[test]
1597 fn case_with_expr_boolean_dictionary() -> Result<()> {
1598 let schema = Schema::new(vec![Field::new(
1599 "a",
1600 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean)),
1601 true,
1602 )]);
1603 let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1604 let values = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
1605 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1606 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1607
1608 let schema = batch.schema();
1609
1610 let when1 = lit(true);
1612 let then1 = lit(123i32);
1613 let when2 = lit(false);
1614 let then2 = lit(456i32);
1615
1616 let expr = generate_case_when_with_type_coercion(
1617 Some(col("a", &schema)?),
1618 vec![(when1, then1), (when2, then2)],
1619 None,
1620 schema.as_ref(),
1621 )?;
1622 let result = expr
1623 .evaluate(&batch)?
1624 .into_array(batch.num_rows())
1625 .expect("Failed to convert to array");
1626 let result = as_int32_array(&result)?;
1627
1628 let expected = &Int32Array::from(vec![Some(123), Some(456), None, Some(123)]);
1629
1630 assert_eq!(expected, result);
1631
1632 Ok(())
1633 }
1634
1635 #[test]
1636 fn case_with_expr_all_null_dictionary() -> Result<()> {
1637 let schema = Schema::new(vec![Field::new(
1638 "a",
1639 DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1640 true,
1641 )]);
1642 let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1643 let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1644 let dictionary = DictionaryArray::new(keys, Arc::new(values));
1645 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1646
1647 let schema = batch.schema();
1648
1649 let when1 = lit("foo");
1651 let then1 = lit(123i32);
1652 let when2 = lit("bar");
1653 let then2 = lit(456i32);
1654
1655 let expr = generate_case_when_with_type_coercion(
1656 Some(col("a", &schema)?),
1657 vec![(when1, then1), (when2, then2)],
1658 None,
1659 schema.as_ref(),
1660 )?;
1661 let result = expr
1662 .evaluate(&batch)?
1663 .into_array(batch.num_rows())
1664 .expect("Failed to convert to array");
1665 let result = as_int32_array(&result)?;
1666
1667 let expected = &Int32Array::from(vec![None, None, None, None]);
1668
1669 assert_eq!(expected, result);
1670
1671 Ok(())
1672 }
1673
1674 #[test]
1675 fn case_with_expr_else() -> Result<()> {
1676 let batch = case_test_batch()?;
1677 let schema = batch.schema();
1678
1679 let when1 = lit("foo");
1681 let then1 = lit(123i32);
1682 let when2 = lit("bar");
1683 let then2 = lit(456i32);
1684 let else_value = lit(999i32);
1685
1686 let expr = generate_case_when_with_type_coercion(
1687 Some(col("a", &schema)?),
1688 vec![(when1, then1), (when2, then2)],
1689 Some(else_value),
1690 schema.as_ref(),
1691 )?;
1692 let result = expr
1693 .evaluate(&batch)?
1694 .into_array(batch.num_rows())
1695 .expect("Failed to convert to array");
1696 let result = as_int32_array(&result)?;
1697
1698 let expected =
1699 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1700
1701 assert_eq!(expected, result);
1702
1703 Ok(())
1704 }
1705
1706 #[test]
1707 fn case_with_expr_divide_by_zero() -> Result<()> {
1708 let batch = case_test_batch1()?;
1709 let schema = batch.schema();
1710
1711 let when1 = lit(0i32);
1713 let then1 = lit(ScalarValue::Float64(None));
1714 let else_value = binary(
1715 lit(25.0f64),
1716 Operator::Divide,
1717 cast(col("a", &schema)?, &batch.schema(), Float64)?,
1718 &batch.schema(),
1719 )?;
1720
1721 let expr = generate_case_when_with_type_coercion(
1722 Some(col("a", &schema)?),
1723 vec![(when1, then1)],
1724 Some(else_value),
1725 schema.as_ref(),
1726 )?;
1727 let result = expr
1728 .evaluate(&batch)?
1729 .into_array(batch.num_rows())
1730 .expect("Failed to convert to array");
1731 let result =
1732 as_float64_array(&result).expect("failed to downcast to Float64Array");
1733
1734 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1735
1736 assert_eq!(expected, result);
1737
1738 Ok(())
1739 }
1740
1741 #[test]
1742 fn case_without_expr() -> Result<()> {
1743 let batch = case_test_batch()?;
1744 let schema = batch.schema();
1745
1746 let when1 = binary(
1748 col("a", &schema)?,
1749 Operator::Eq,
1750 lit("foo"),
1751 &batch.schema(),
1752 )?;
1753 let then1 = lit(123i32);
1754 let when2 = binary(
1755 col("a", &schema)?,
1756 Operator::Eq,
1757 lit("bar"),
1758 &batch.schema(),
1759 )?;
1760 let then2 = lit(456i32);
1761
1762 let expr = generate_case_when_with_type_coercion(
1763 None,
1764 vec![(when1, then1), (when2, then2)],
1765 None,
1766 schema.as_ref(),
1767 )?;
1768 let result = expr
1769 .evaluate(&batch)?
1770 .into_array(batch.num_rows())
1771 .expect("Failed to convert to array");
1772 let result = as_int32_array(&result)?;
1773
1774 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1775
1776 assert_eq!(expected, result);
1777
1778 Ok(())
1779 }
1780
1781 #[test]
1782 fn case_with_expr_when_null() -> Result<()> {
1783 let batch = case_test_batch()?;
1784 let schema = batch.schema();
1785
1786 let when1 = lit(ScalarValue::Utf8(None));
1788 let then1 = lit(0i32);
1789 let when2 = col("a", &schema)?;
1790 let then2 = lit(123i32);
1791 let else_value = lit(999i32);
1792
1793 let expr = generate_case_when_with_type_coercion(
1794 Some(col("a", &schema)?),
1795 vec![(when1, then1), (when2, then2)],
1796 Some(else_value),
1797 schema.as_ref(),
1798 )?;
1799 let result = expr
1800 .evaluate(&batch)?
1801 .into_array(batch.num_rows())
1802 .expect("Failed to convert to array");
1803 let result = as_int32_array(&result)?;
1804
1805 let expected =
1806 &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
1807
1808 assert_eq!(expected, result);
1809
1810 Ok(())
1811 }
1812
1813 #[test]
1814 fn case_without_expr_divide_by_zero() -> Result<()> {
1815 let batch = case_test_batch1()?;
1816 let schema = batch.schema();
1817
1818 let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1820 let then1 = binary(
1821 lit(25.0f64),
1822 Operator::Divide,
1823 cast(col("a", &schema)?, &batch.schema(), Float64)?,
1824 &batch.schema(),
1825 )?;
1826 let x = lit(ScalarValue::Float64(None));
1827
1828 let expr = generate_case_when_with_type_coercion(
1829 None,
1830 vec![(when1, then1)],
1831 Some(x),
1832 schema.as_ref(),
1833 )?;
1834 let result = expr
1835 .evaluate(&batch)?
1836 .into_array(batch.num_rows())
1837 .expect("Failed to convert to array");
1838 let result =
1839 as_float64_array(&result).expect("failed to downcast to Float64Array");
1840
1841 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1842
1843 assert_eq!(expected, result);
1844
1845 Ok(())
1846 }
1847
1848 fn case_test_batch1() -> Result<RecordBatch> {
1849 let schema = Schema::new(vec![
1850 Field::new("a", DataType::Int32, true),
1851 Field::new("b", DataType::Int32, true),
1852 Field::new("c", DataType::Int32, true),
1853 ]);
1854 let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1855 let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1856 let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1857 let batch = RecordBatch::try_new(
1858 Arc::new(schema),
1859 vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1860 )?;
1861 Ok(batch)
1862 }
1863
1864 #[test]
1865 fn case_without_expr_else() -> Result<()> {
1866 let batch = case_test_batch()?;
1867 let schema = batch.schema();
1868
1869 let when1 = binary(
1871 col("a", &schema)?,
1872 Operator::Eq,
1873 lit("foo"),
1874 &batch.schema(),
1875 )?;
1876 let then1 = lit(123i32);
1877 let when2 = binary(
1878 col("a", &schema)?,
1879 Operator::Eq,
1880 lit("bar"),
1881 &batch.schema(),
1882 )?;
1883 let then2 = lit(456i32);
1884 let else_value = lit(999i32);
1885
1886 let expr = generate_case_when_with_type_coercion(
1887 None,
1888 vec![(when1, then1), (when2, then2)],
1889 Some(else_value),
1890 schema.as_ref(),
1891 )?;
1892 let result = expr
1893 .evaluate(&batch)?
1894 .into_array(batch.num_rows())
1895 .expect("Failed to convert to array");
1896 let result = as_int32_array(&result)?;
1897
1898 let expected =
1899 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1900
1901 assert_eq!(expected, result);
1902
1903 Ok(())
1904 }
1905
1906 #[test]
1907 fn case_with_type_cast() -> Result<()> {
1908 let batch = case_test_batch()?;
1909 let schema = batch.schema();
1910
1911 let when = binary(
1913 col("a", &schema)?,
1914 Operator::Eq,
1915 lit("foo"),
1916 &batch.schema(),
1917 )?;
1918 let then = lit(123.3f64);
1919 let else_value = lit(999i32);
1920
1921 let expr = generate_case_when_with_type_coercion(
1922 None,
1923 vec![(when, then)],
1924 Some(else_value),
1925 schema.as_ref(),
1926 )?;
1927 let result = expr
1928 .evaluate(&batch)?
1929 .into_array(batch.num_rows())
1930 .expect("Failed to convert to array");
1931 let result =
1932 as_float64_array(&result).expect("failed to downcast to Float64Array");
1933
1934 let expected =
1935 &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
1936
1937 assert_eq!(expected, result);
1938
1939 Ok(())
1940 }
1941
1942 #[test]
1943 fn case_with_matches_and_nulls() -> Result<()> {
1944 let batch = case_test_batch_nulls()?;
1945 let schema = batch.schema();
1946
1947 let when = binary(
1949 col("load4", &schema)?,
1950 Operator::Eq,
1951 lit(1.77f64),
1952 &batch.schema(),
1953 )?;
1954 let then = col("load4", &schema)?;
1955
1956 let expr = generate_case_when_with_type_coercion(
1957 None,
1958 vec![(when, then)],
1959 None,
1960 schema.as_ref(),
1961 )?;
1962 let result = expr
1963 .evaluate(&batch)?
1964 .into_array(batch.num_rows())
1965 .expect("Failed to convert to array");
1966 let result =
1967 as_float64_array(&result).expect("failed to downcast to Float64Array");
1968
1969 let expected =
1970 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1971
1972 assert_eq!(expected, result);
1973
1974 Ok(())
1975 }
1976
1977 #[test]
1978 fn case_with_scalar_predicate() -> Result<()> {
1979 let batch = case_test_batch_nulls()?;
1980 let schema = batch.schema();
1981
1982 let when = lit(true);
1984 let then = col("load4", &schema)?;
1985 let expr = generate_case_when_with_type_coercion(
1986 None,
1987 vec![(when, then)],
1988 None,
1989 schema.as_ref(),
1990 )?;
1991
1992 let result = expr
1994 .evaluate(&batch)?
1995 .into_array(batch.num_rows())
1996 .expect("Failed to convert to array");
1997 let result =
1998 as_float64_array(&result).expect("failed to downcast to Float64Array");
1999 let expected = &Float64Array::from(vec![
2000 Some(1.77),
2001 None,
2002 None,
2003 Some(1.78),
2004 None,
2005 Some(1.77),
2006 ]);
2007 assert_eq!(expected, result);
2008
2009 let expected = Float64Array::from(vec![Some(1.1)]);
2011 let batch =
2012 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
2013 let result = expr
2014 .evaluate(&batch)?
2015 .into_array(batch.num_rows())
2016 .expect("Failed to convert to array");
2017 let result =
2018 as_float64_array(&result).expect("failed to downcast to Float64Array");
2019 assert_eq!(&expected, result);
2020
2021 Ok(())
2022 }
2023
2024 #[test]
2025 fn case_expr_matches_and_nulls() -> Result<()> {
2026 let batch = case_test_batch_nulls()?;
2027 let schema = batch.schema();
2028
2029 let expr = col("load4", &schema)?;
2031 let when = lit(1.77f64);
2032 let then = col("load4", &schema)?;
2033
2034 let expr = generate_case_when_with_type_coercion(
2035 Some(expr),
2036 vec![(when, then)],
2037 None,
2038 schema.as_ref(),
2039 )?;
2040 let result = expr
2041 .evaluate(&batch)?
2042 .into_array(batch.num_rows())
2043 .expect("Failed to convert to array");
2044 let result =
2045 as_float64_array(&result).expect("failed to downcast to Float64Array");
2046
2047 let expected =
2048 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
2049
2050 assert_eq!(expected, result);
2051
2052 Ok(())
2053 }
2054
2055 #[test]
2056 fn test_when_null_and_some_cond_else_null() -> Result<()> {
2057 let batch = case_test_batch()?;
2058 let schema = batch.schema();
2059
2060 let when = binary(
2061 Arc::new(Literal::new(ScalarValue::Boolean(None))),
2062 Operator::And,
2063 binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
2064 &schema,
2065 )?;
2066 let then = col("a", &schema)?;
2067
2068 let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
2070 let result = expr
2071 .evaluate(&batch)?
2072 .into_array(batch.num_rows())
2073 .expect("Failed to convert to array");
2074 let result = as_string_array(&result);
2075
2076 assert_eq!(result.logical_null_count(), batch.num_rows());
2078 Ok(())
2079 }
2080
2081 fn case_test_batch() -> Result<RecordBatch> {
2082 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2083 let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
2084 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
2085 Ok(batch)
2086 }
2087
2088 fn case_test_batch_nulls() -> Result<RecordBatch> {
2091 let load4: Float64Array = vec![
2092 Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
2099 .into_iter()
2100 .collect();
2101
2102 let null_buffer = Buffer::from([0b00101001u8]);
2103 let load4 = load4
2104 .into_data()
2105 .into_builder()
2106 .null_bit_buffer(Some(null_buffer))
2107 .build()
2108 .unwrap();
2109 let load4: Float64Array = load4.into();
2110
2111 let batch =
2112 RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
2113 Ok(batch)
2114 }
2115
2116 #[test]
2117 fn case_test_incompatible() -> Result<()> {
2118 let batch = case_test_batch()?;
2121 let schema = batch.schema();
2122
2123 let when1 = binary(
2125 col("a", &schema)?,
2126 Operator::Eq,
2127 lit("foo"),
2128 &batch.schema(),
2129 )?;
2130 let then1 = lit(123i32);
2131 let when2 = binary(
2132 col("a", &schema)?,
2133 Operator::Eq,
2134 lit("bar"),
2135 &batch.schema(),
2136 )?;
2137 let then2 = lit(true);
2138
2139 let expr = generate_case_when_with_type_coercion(
2140 None,
2141 vec![(when1, then1), (when2, then2)],
2142 None,
2143 schema.as_ref(),
2144 );
2145 assert!(expr.is_err());
2146
2147 let when1 = binary(
2152 col("a", &schema)?,
2153 Operator::Eq,
2154 lit("foo"),
2155 &batch.schema(),
2156 )?;
2157 let then1 = lit(123i32);
2158 let when2 = binary(
2159 col("a", &schema)?,
2160 Operator::Eq,
2161 lit("bar"),
2162 &batch.schema(),
2163 )?;
2164 let then2 = lit(456i64);
2165 let else_expr = lit(1.23f64);
2166
2167 let expr = generate_case_when_with_type_coercion(
2168 None,
2169 vec![(when1, then1), (when2, then2)],
2170 Some(else_expr),
2171 schema.as_ref(),
2172 );
2173 assert!(expr.is_ok());
2174 let result_type = expr.unwrap().data_type(schema.as_ref())?;
2175 assert_eq!(Float64, result_type);
2176 Ok(())
2177 }
2178
2179 #[test]
2180 fn case_eq() -> Result<()> {
2181 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2182
2183 let when1 = lit("foo");
2184 let then1 = lit(123i32);
2185 let when2 = lit("bar");
2186 let then2 = lit(456i32);
2187 let else_value = lit(999i32);
2188
2189 let expr1 = generate_case_when_with_type_coercion(
2190 Some(col("a", &schema)?),
2191 vec![
2192 (Arc::clone(&when1), Arc::clone(&then1)),
2193 (Arc::clone(&when2), Arc::clone(&then2)),
2194 ],
2195 Some(Arc::clone(&else_value)),
2196 &schema,
2197 )?;
2198
2199 let expr2 = generate_case_when_with_type_coercion(
2200 Some(col("a", &schema)?),
2201 vec![
2202 (Arc::clone(&when1), Arc::clone(&then1)),
2203 (Arc::clone(&when2), Arc::clone(&then2)),
2204 ],
2205 Some(Arc::clone(&else_value)),
2206 &schema,
2207 )?;
2208
2209 let expr3 = generate_case_when_with_type_coercion(
2210 Some(col("a", &schema)?),
2211 vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
2212 None,
2213 &schema,
2214 )?;
2215
2216 let expr4 = generate_case_when_with_type_coercion(
2217 Some(col("a", &schema)?),
2218 vec![(when1, then1)],
2219 Some(else_value),
2220 &schema,
2221 )?;
2222
2223 assert!(expr1.eq(&expr2));
2224 assert!(expr2.eq(&expr1));
2225
2226 assert!(expr2.ne(&expr3));
2227 assert!(expr3.ne(&expr2));
2228
2229 assert!(expr1.ne(&expr4));
2230 assert!(expr4.ne(&expr1));
2231
2232 Ok(())
2233 }
2234
2235 #[test]
2236 fn case_transform() -> Result<()> {
2237 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2238
2239 let when1 = lit("foo");
2240 let then1 = lit(123i32);
2241 let when2 = lit("bar");
2242 let then2 = lit(456i32);
2243 let else_value = lit(999i32);
2244
2245 let expr = generate_case_when_with_type_coercion(
2246 Some(col("a", &schema)?),
2247 vec![
2248 (Arc::clone(&when1), Arc::clone(&then1)),
2249 (Arc::clone(&when2), Arc::clone(&then2)),
2250 ],
2251 Some(Arc::clone(&else_value)),
2252 &schema,
2253 )?;
2254
2255 let expr2 = Arc::clone(&expr)
2256 .transform(|e| {
2257 let transformed = match e.downcast_ref::<Literal>() {
2258 Some(lit_value) => match lit_value.value() {
2259 ScalarValue::Utf8(Some(str_value)) => {
2260 Some(lit(str_value.to_uppercase()))
2261 }
2262 _ => None,
2263 },
2264 _ => None,
2265 };
2266 Ok(if let Some(transformed) = transformed {
2267 Transformed::yes(transformed)
2268 } else {
2269 Transformed::no(e)
2270 })
2271 })
2272 .data()
2273 .unwrap();
2274
2275 let expr3 = Arc::clone(&expr)
2276 .transform_down(|e| {
2277 let transformed = match e.downcast_ref::<Literal>() {
2278 Some(lit_value) => match lit_value.value() {
2279 ScalarValue::Utf8(Some(str_value)) => {
2280 Some(lit(str_value.to_uppercase()))
2281 }
2282 _ => None,
2283 },
2284 _ => None,
2285 };
2286 Ok(if let Some(transformed) = transformed {
2287 Transformed::yes(transformed)
2288 } else {
2289 Transformed::no(e)
2290 })
2291 })
2292 .data()
2293 .unwrap();
2294
2295 assert!(expr.ne(&expr2));
2296 assert!(expr2.eq(&expr3));
2297
2298 Ok(())
2299 }
2300
2301 #[test]
2302 fn test_column_or_null_specialization() -> Result<()> {
2303 let mut c1 = Int32Builder::new();
2305 let mut c2 = StringBuilder::new();
2306 for i in 0..1000 {
2307 c1.append_value(i);
2308 if i % 7 == 0 {
2309 c2.append_null();
2310 } else {
2311 c2.append_value(format!("string {i}"));
2312 }
2313 }
2314 let c1 = Arc::new(c1.finish());
2315 let c2 = Arc::new(c2.finish());
2316 let schema = Schema::new(vec![
2317 Field::new("c1", DataType::Int32, true),
2318 Field::new("c2", DataType::Utf8, true),
2319 ]);
2320 let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
2321
2322 let predicate = Arc::new(BinaryExpr::new(
2324 make_col("c1", 0),
2325 Operator::LtEq,
2326 make_lit_i32(250),
2327 ));
2328 let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
2329 assert_eq!(expr.eval_method, EvalMethod::InfallibleExprOrNull);
2330 match expr.evaluate(&batch)? {
2331 ColumnarValue::Array(array) => {
2332 assert_eq!(1000, array.len());
2333 assert_eq!(785, array.null_count());
2334 }
2335 _ => unreachable!(),
2336 }
2337 Ok(())
2338 }
2339
2340 #[test]
2341 fn test_expr_or_expr_specialization() -> Result<()> {
2342 let batch = case_test_batch1()?;
2343 let schema = batch.schema();
2344 let when = binary(
2345 col("a", &schema)?,
2346 Operator::LtEq,
2347 lit(2i32),
2348 &batch.schema(),
2349 )?;
2350 let then = col("b", &schema)?;
2351 let else_expr = col("c", &schema)?;
2352 let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
2353 assert!(matches!(
2354 expr.eval_method,
2355 EvalMethod::ExpressionOrExpression(_)
2356 ));
2357 let result = expr
2358 .evaluate(&batch)?
2359 .into_array(batch.num_rows())
2360 .expect("Failed to convert to array");
2361 let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
2362
2363 let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
2364
2365 assert_eq!(expected, result);
2366 Ok(())
2367 }
2368
2369 fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2370 Arc::new(Column::new(name, index))
2371 }
2372
2373 fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2374 Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2375 }
2376
2377 fn generate_case_when_with_type_coercion(
2378 expr: Option<Arc<dyn PhysicalExpr>>,
2379 when_thens: Vec<WhenThen>,
2380 else_expr: Option<Arc<dyn PhysicalExpr>>,
2381 input_schema: &Schema,
2382 ) -> Result<Arc<dyn PhysicalExpr>> {
2383 let coerce_type =
2384 get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2385 let (when_thens, else_expr) = match coerce_type {
2386 None => plan_err!(
2387 "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2388 ),
2389 Some(data_type) => {
2390 let left = when_thens
2392 .into_iter()
2393 .map(|(when, then)| {
2394 let then = try_cast(then, input_schema, data_type.clone())?;
2395 Ok((when, then))
2396 })
2397 .collect::<Result<Vec<_>>>()?;
2398 let right = match else_expr {
2399 None => None,
2400 Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2401 };
2402
2403 Ok((left, right))
2404 }
2405 }?;
2406 case(expr, when_thens, else_expr)
2407 }
2408
2409 fn get_case_common_type(
2410 when_thens: &[WhenThen],
2411 else_expr: Option<Arc<dyn PhysicalExpr>>,
2412 input_schema: &Schema,
2413 ) -> Option<DataType> {
2414 let thens_type = when_thens
2415 .iter()
2416 .map(|when_then| {
2417 let data_type = &when_then.1.data_type(input_schema).unwrap();
2418 data_type.clone()
2419 })
2420 .collect::<Vec<_>>();
2421 let else_type = match else_expr {
2422 None => {
2423 thens_type[0].clone()
2425 }
2426 Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2427 };
2428 thens_type
2429 .iter()
2430 .try_fold(else_type, |left_type, right_type| {
2431 type_union_coercion(&left_type, right_type)
2432 })
2433 }
2434
2435 #[test]
2436 fn test_fmt_sql() -> Result<()> {
2437 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2438
2439 let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
2441 let then = lit(123.3f64);
2442 let else_value = lit(999i32);
2443
2444 let expr = generate_case_when_with_type_coercion(
2445 None,
2446 vec![(when, then)],
2447 Some(else_value),
2448 &schema,
2449 )?;
2450
2451 let display_string = expr.to_string();
2452 assert_eq!(
2453 display_string,
2454 "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2455 );
2456
2457 let sql_string = fmt_sql(expr.as_ref()).to_string();
2458 assert_eq!(
2459 sql_string,
2460 "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2461 );
2462
2463 Ok(())
2464 }
2465
2466 fn when_then_else(
2467 when: &Arc<dyn PhysicalExpr>,
2468 then: &Arc<dyn PhysicalExpr>,
2469 els: &Arc<dyn PhysicalExpr>,
2470 ) -> Result<Arc<dyn PhysicalExpr>> {
2471 let case = CaseExpr::try_new(
2472 None,
2473 vec![(Arc::clone(when), Arc::clone(then))],
2474 Some(Arc::clone(els)),
2475 )?;
2476 Ok(Arc::new(case))
2477 }
2478
2479 #[test]
2480 fn test_case_expression_nullability_with_nullable_column() -> Result<()> {
2481 case_expression_nullability(true)
2482 }
2483
2484 #[test]
2485 fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> {
2486 case_expression_nullability(false)
2487 }
2488
2489 fn case_expression_nullability(col_is_nullable: bool) -> Result<()> {
2490 let schema =
2491 Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]);
2492
2493 let foo = col("foo", &schema)?;
2494 let foo_is_not_null = is_not_null(Arc::clone(&foo))?;
2495 let foo_is_null = expressions::is_null(Arc::clone(&foo))?;
2496 let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?;
2497 let zero = lit(0);
2498 let foo_eq_zero =
2499 binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?;
2500
2501 assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema);
2502 assert_not_nullable(when_then_else(¬_foo_is_null, &foo, &zero)?, &schema);
2503 assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema);
2504
2505 assert_not_nullable(
2506 when_then_else(
2507 &binary(
2508 Arc::clone(&foo_is_not_null),
2509 Operator::And,
2510 Arc::clone(&foo_eq_zero),
2511 &schema,
2512 )?,
2513 &foo,
2514 &zero,
2515 )?,
2516 &schema,
2517 );
2518
2519 assert_not_nullable(
2520 when_then_else(
2521 &binary(
2522 Arc::clone(&foo_eq_zero),
2523 Operator::And,
2524 Arc::clone(&foo_is_not_null),
2525 &schema,
2526 )?,
2527 &foo,
2528 &zero,
2529 )?,
2530 &schema,
2531 );
2532
2533 assert_not_nullable(
2534 when_then_else(
2535 &binary(
2536 Arc::clone(&foo_is_not_null),
2537 Operator::Or,
2538 Arc::clone(&foo_eq_zero),
2539 &schema,
2540 )?,
2541 &foo,
2542 &zero,
2543 )?,
2544 &schema,
2545 );
2546
2547 assert_not_nullable(
2548 when_then_else(
2549 &binary(
2550 Arc::clone(&foo_eq_zero),
2551 Operator::Or,
2552 Arc::clone(&foo_is_not_null),
2553 &schema,
2554 )?,
2555 &foo,
2556 &zero,
2557 )?,
2558 &schema,
2559 );
2560
2561 assert_nullability(
2562 when_then_else(
2563 &binary(
2564 Arc::clone(&foo_is_null),
2565 Operator::Or,
2566 Arc::clone(&foo_eq_zero),
2567 &schema,
2568 )?,
2569 &foo,
2570 &zero,
2571 )?,
2572 &schema,
2573 col_is_nullable,
2574 );
2575
2576 assert_nullability(
2577 when_then_else(
2578 &binary(
2579 binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?,
2580 Operator::Or,
2581 Arc::clone(&foo_is_null),
2582 &schema,
2583 )?,
2584 &foo,
2585 &zero,
2586 )?,
2587 &schema,
2588 col_is_nullable,
2589 );
2590
2591 assert_not_nullable(
2592 when_then_else(
2593 &binary(
2594 binary(
2595 binary(
2596 Arc::clone(&foo),
2597 Operator::Eq,
2598 Arc::clone(&zero),
2599 &schema,
2600 )?,
2601 Operator::And,
2602 Arc::clone(&foo_is_not_null),
2603 &schema,
2604 )?,
2605 Operator::Or,
2606 binary(
2607 binary(
2608 Arc::clone(&foo),
2609 Operator::Eq,
2610 Arc::clone(&foo),
2611 &schema,
2612 )?,
2613 Operator::And,
2614 Arc::clone(&foo_is_not_null),
2615 &schema,
2616 )?,
2617 &schema,
2618 )?,
2619 &foo,
2620 &zero,
2621 )?,
2622 &schema,
2623 );
2624
2625 Ok(())
2626 }
2627
2628 fn assert_not_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2629 assert!(!expr.nullable(schema).unwrap());
2630 }
2631
2632 fn assert_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2633 assert!(expr.nullable(schema).unwrap());
2634 }
2635
2636 fn assert_nullability(expr: Arc<dyn PhysicalExpr>, schema: &Schema, nullable: bool) {
2637 if nullable {
2638 assert_nullable(expr, schema);
2639 } else {
2640 assert_not_nullable(expr, schema);
2641 }
2642 }
2643
2644 fn test_case_when_literal_lookup(
2647 values: ArrayRef,
2648 lookup_map: &[(ScalarValue, ScalarValue)],
2649 else_value: Option<ScalarValue>,
2650 expected: ArrayRef,
2651 ) {
2652 let schema = Schema::new(vec![Field::new(
2659 "a",
2660 values.data_type().clone(),
2661 values.is_nullable(),
2662 )]);
2663 let schema = Arc::new(schema);
2664
2665 let batch = RecordBatch::try_new(schema, vec![values])
2666 .expect("failed to create RecordBatch");
2667
2668 let schema = batch.schema_ref();
2669 let case = col("a", schema).expect("failed to create col");
2670
2671 let when_then = lookup_map
2672 .iter()
2673 .map(|(when, then)| {
2674 (
2675 Arc::new(Literal::new(when.clone())) as _,
2676 Arc::new(Literal::new(then.clone())) as _,
2677 )
2678 })
2679 .collect::<Vec<WhenThen>>();
2680
2681 let else_expr = else_value.map(|else_value| {
2682 Arc::new(Literal::new(else_value)) as Arc<dyn PhysicalExpr>
2683 });
2684 let expr = CaseExpr::try_new(Some(case), when_then, else_expr)
2685 .expect("failed to create case");
2686
2687 assert!(
2689 matches!(
2690 expr.eval_method,
2691 EvalMethod::WithExprScalarLookupTable { .. }
2692 ),
2693 "we should use the expected eval method"
2694 );
2695
2696 let actual = expr
2697 .evaluate(&batch)
2698 .expect("failed to evaluate case")
2699 .into_array(batch.num_rows())
2700 .expect("Failed to convert to array");
2701
2702 assert_eq!(
2703 actual.data_type(),
2704 expected.data_type(),
2705 "Data type mismatch"
2706 );
2707
2708 assert_eq!(
2709 actual.as_ref(),
2710 expected.as_ref(),
2711 "actual (left) does not match expected (right)"
2712 );
2713 }
2714
2715 fn create_lookup<When, Then>(
2716 when_then_pairs: impl IntoIterator<Item = (When, Then)>,
2717 ) -> Vec<(ScalarValue, ScalarValue)>
2718 where
2719 ScalarValue: From<When>,
2720 ScalarValue: From<Then>,
2721 {
2722 when_then_pairs
2723 .into_iter()
2724 .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then)))
2725 .collect()
2726 }
2727
2728 fn create_input_and_expected<Input, Expected, InputFromItem, ExpectedFromItem>(
2729 input_and_expected_pairs: impl IntoIterator<Item = (InputFromItem, ExpectedFromItem)>,
2730 ) -> (Input, Expected)
2731 where
2732 Input: Array + From<Vec<InputFromItem>>,
2733 Expected: Array + From<Vec<ExpectedFromItem>>,
2734 {
2735 let (input_items, expected_items): (Vec<InputFromItem>, Vec<ExpectedFromItem>) =
2736 input_and_expected_pairs.into_iter().unzip();
2737
2738 (Input::from(input_items), Expected::from(expected_items))
2739 }
2740
2741 fn test_lookup_eval_with_and_without_else(
2742 lookup_map: &[(ScalarValue, ScalarValue)],
2743 input_values: ArrayRef,
2744 expected: StringArray,
2745 ) {
2746 test_case_when_literal_lookup(
2748 Arc::clone(&input_values),
2749 lookup_map,
2750 None,
2751 Arc::new(expected.clone()),
2752 );
2753
2754 let else_value = "___fallback___";
2756
2757 let expected_with_else = expected
2759 .iter()
2760 .map(|item| item.unwrap_or(else_value))
2761 .map(Some)
2762 .collect::<StringArray>();
2763
2764 test_case_when_literal_lookup(
2766 input_values,
2767 lookup_map,
2768 Some(ScalarValue::Utf8(Some(else_value.to_string()))),
2769 Arc::new(expected_with_else),
2770 );
2771 }
2772
2773 #[test]
2774 fn test_case_when_literal_lookup_int32_to_string() {
2775 let lookup_map = create_lookup([
2776 (Some(4), Some("four")),
2777 (Some(2), Some("two")),
2778 (Some(3), Some("three")),
2779 (Some(1), Some("one")),
2780 ]);
2781
2782 let (input_values, expected) =
2783 create_input_and_expected::<Int32Array, StringArray, _, _>([
2784 (1, Some("one")),
2785 (2, Some("two")),
2786 (3, Some("three")),
2787 (3, Some("three")),
2788 (2, Some("two")),
2789 (3, Some("three")),
2790 (5, None), (5, None), (3, Some("three")),
2793 (5, None), ]);
2795
2796 test_lookup_eval_with_and_without_else(
2797 &lookup_map,
2798 Arc::new(input_values),
2799 expected,
2800 );
2801 }
2802
2803 #[test]
2804 fn test_case_when_literal_lookup_none_case_should_never_match() {
2805 let lookup_map = create_lookup([
2806 (Some(4), Some("four")),
2807 (None, Some("none")),
2808 (Some(2), Some("two")),
2809 (Some(1), Some("one")),
2810 ]);
2811
2812 let (input_values, expected) =
2813 create_input_and_expected::<Int32Array, StringArray, _, _>([
2814 (Some(1), Some("one")),
2815 (Some(5), None), (None, None), (Some(2), Some("two")),
2818 (None, None), (None, None), (Some(2), Some("two")),
2821 (Some(5), None), ]);
2823
2824 test_lookup_eval_with_and_without_else(
2825 &lookup_map,
2826 Arc::new(input_values),
2827 expected,
2828 );
2829 }
2830
2831 #[test]
2832 fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() {
2833 let lookup_map = create_lookup([
2834 (Some(4), Some("four")),
2835 (Some(4), Some("no 4")),
2836 (Some(2), Some("two")),
2837 (Some(2), Some("no 2")),
2838 (Some(3), Some("three")),
2839 (Some(3), Some("no 3")),
2840 (Some(2), Some("no 2")),
2841 (Some(4), Some("no 4")),
2842 (Some(2), Some("no 2")),
2843 (Some(3), Some("no 3")),
2844 (Some(4), Some("no 4")),
2845 (Some(2), Some("no 2")),
2846 (Some(3), Some("no 3")),
2847 (Some(3), Some("no 3")),
2848 ]);
2849
2850 let (input_values, expected) =
2851 create_input_and_expected::<Int32Array, StringArray, _, _>([
2852 (1, None), (2, Some("two")),
2854 (3, Some("three")),
2855 (3, Some("three")),
2856 (2, Some("two")),
2857 (3, Some("three")),
2858 (5, None), (5, None), (3, Some("three")),
2861 (5, None), ]);
2863
2864 test_lookup_eval_with_and_without_else(
2865 &lookup_map,
2866 Arc::new(input_values),
2867 expected,
2868 );
2869 }
2870
2871 #[test]
2872 fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases()
2873 {
2874 let lookup_map = create_lookup([
2875 (Some(4.0), Some("four point zero")),
2876 (Some(f32::NAN), Some("NaN")),
2877 (Some(3.2), Some("three point two")),
2878 (Some(f32::NAN), Some("should not use this NaN branch")),
2880 (Some(f32::INFINITY), Some("Infinity")),
2881 (Some(0.0), Some("zero")),
2882 (
2884 Some(f32::INFINITY),
2885 Some("should not use this Infinity branch"),
2886 ),
2887 (Some(1.1), Some("one point one")),
2888 ]);
2889
2890 let (input_values, expected) =
2891 create_input_and_expected::<Float32Array, StringArray, _, _>([
2892 (1.1, Some("one point one")),
2893 (f32::NAN, Some("NaN")),
2894 (3.2, Some("three point two")),
2895 (3.2, Some("three point two")),
2896 (0.0, Some("zero")),
2897 (f32::INFINITY, Some("Infinity")),
2898 (3.2, Some("three point two")),
2899 (f32::NEG_INFINITY, None), (f32::NEG_INFINITY, None), (3.2, Some("three point two")),
2902 (-0.0, None), ]);
2904
2905 test_lookup_eval_with_and_without_else(
2906 &lookup_map,
2907 Arc::new(input_values),
2908 expected,
2909 );
2910 }
2911
2912 #[test]
2913 fn test_case_when_literal_lookup_f16_to_string_with_special_values() {
2914 let lookup_map = create_lookup([
2915 (
2916 ScalarValue::Float16(Some(f16::from_f32(3.2))),
2917 Some("3 dot 2"),
2918 ),
2919 (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")),
2920 (
2921 ScalarValue::Float16(Some(f16::from_f32(17.4))),
2922 Some("17 dot 4"),
2923 ),
2924 (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")),
2925 (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")),
2926 ]);
2927
2928 let (input_values, expected) =
2929 create_input_and_expected::<Float16Array, StringArray, _, _>([
2930 (f16::from_f32(3.2), Some("3 dot 2")),
2931 (f16::NAN, Some("NaN")),
2932 (f16::from_f32(17.4), Some("17 dot 4")),
2933 (f16::from_f32(17.4), Some("17 dot 4")),
2934 (f16::INFINITY, Some("Infinity")),
2935 (f16::from_f32(17.4), Some("17 dot 4")),
2936 (f16::NEG_INFINITY, None), (f16::NEG_INFINITY, None), (f16::from_f32(17.4), Some("17 dot 4")),
2939 (f16::NEG_ZERO, None), ]);
2941
2942 test_lookup_eval_with_and_without_else(
2943 &lookup_map,
2944 Arc::new(input_values),
2945 expected,
2946 );
2947 }
2948
2949 #[test]
2950 fn test_case_when_literal_lookup_f32_to_string_with_special_values() {
2951 let lookup_map = create_lookup([
2952 (3.2, Some("3 dot 2")),
2953 (f32::NAN, Some("NaN")),
2954 (17.4, Some("17 dot 4")),
2955 (f32::INFINITY, Some("Infinity")),
2956 (f32::ZERO, Some("zero")),
2957 ]);
2958
2959 let (input_values, expected) =
2960 create_input_and_expected::<Float32Array, StringArray, _, _>([
2961 (3.2, Some("3 dot 2")),
2962 (f32::NAN, Some("NaN")),
2963 (17.4, Some("17 dot 4")),
2964 (17.4, Some("17 dot 4")),
2965 (f32::INFINITY, Some("Infinity")),
2966 (17.4, Some("17 dot 4")),
2967 (f32::NEG_INFINITY, None), (f32::NEG_INFINITY, None), (17.4, Some("17 dot 4")),
2970 (-0.0, None), ]);
2972
2973 test_lookup_eval_with_and_without_else(
2974 &lookup_map,
2975 Arc::new(input_values),
2976 expected,
2977 );
2978 }
2979
2980 #[test]
2981 fn test_case_when_literal_lookup_f64_to_string_with_special_values() {
2982 let lookup_map = create_lookup([
2983 (3.2, Some("3 dot 2")),
2984 (f64::NAN, Some("NaN")),
2985 (17.4, Some("17 dot 4")),
2986 (f64::INFINITY, Some("Infinity")),
2987 (f64::ZERO, Some("zero")),
2988 ]);
2989
2990 let (input_values, expected) =
2991 create_input_and_expected::<Float64Array, StringArray, _, _>([
2992 (3.2, Some("3 dot 2")),
2993 (f64::NAN, Some("NaN")),
2994 (17.4, Some("17 dot 4")),
2995 (17.4, Some("17 dot 4")),
2996 (f64::INFINITY, Some("Infinity")),
2997 (17.4, Some("17 dot 4")),
2998 (f64::NEG_INFINITY, None), (f64::NEG_INFINITY, None), (17.4, Some("17 dot 4")),
3001 (-0.0, None), ]);
3003
3004 test_lookup_eval_with_and_without_else(
3005 &lookup_map,
3006 Arc::new(input_values),
3007 expected,
3008 );
3009 }
3010
3011 #[test]
3013 fn test_decimal_with_non_default_precision_and_scale() {
3014 let lookup_map = create_lookup([
3015 (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")),
3016 (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")),
3017 (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")),
3018 (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")),
3019 ]);
3020
3021 let (input_values, expected) =
3022 create_input_and_expected::<Decimal32Array, StringArray, _, _>([
3023 (1, Some("one")),
3024 (2, Some("two")),
3025 (3, Some("three")),
3026 (3, Some("three")),
3027 (2, Some("two")),
3028 (3, Some("three")),
3029 (5, None), (5, None), (3, Some("three")),
3032 (5, None), ]);
3034
3035 let input_values = input_values
3036 .with_precision_and_scale(3, 2)
3037 .expect("must be able to set precision and scale");
3038
3039 test_lookup_eval_with_and_without_else(
3040 &lookup_map,
3041 Arc::new(input_values),
3042 expected,
3043 );
3044 }
3045
3046 #[test]
3048 fn test_timestamp_with_non_default_timezone() {
3049 let timezone: Option<Arc<str>> = Some("-10:00".into());
3050 let lookup_map = create_lookup([
3051 (
3052 ScalarValue::TimestampMillisecond(Some(4), timezone.clone()),
3053 Some("four"),
3054 ),
3055 (
3056 ScalarValue::TimestampMillisecond(Some(2), timezone.clone()),
3057 Some("two"),
3058 ),
3059 (
3060 ScalarValue::TimestampMillisecond(Some(3), timezone.clone()),
3061 Some("three"),
3062 ),
3063 (
3064 ScalarValue::TimestampMillisecond(Some(1), timezone.clone()),
3065 Some("one"),
3066 ),
3067 ]);
3068
3069 let (input_values, expected) =
3070 create_input_and_expected::<TimestampMillisecondArray, StringArray, _, _>([
3071 (1, Some("one")),
3072 (2, Some("two")),
3073 (3, Some("three")),
3074 (3, Some("three")),
3075 (2, Some("two")),
3076 (3, Some("three")),
3077 (5, None), (5, None), (3, Some("three")),
3080 (5, None), ]);
3082
3083 let input_values = input_values.with_timezone_opt(timezone);
3084
3085 test_lookup_eval_with_and_without_else(
3086 &lookup_map,
3087 Arc::new(input_values),
3088 expected,
3089 );
3090 }
3091
3092 #[test]
3093 fn test_with_strings_to_int32() {
3094 let lookup_map = create_lookup([
3095 (Some("why"), Some(42)),
3096 (Some("what"), Some(22)),
3097 (Some("when"), Some(17)),
3098 ]);
3099
3100 let (input_values, expected) =
3101 create_input_and_expected::<StringArray, Int32Array, _, _>([
3102 (Some("why"), Some(42)),
3103 (Some("5"), None), (None, None), (Some("what"), Some(22)),
3106 (None, None), (None, None), (Some("what"), Some(22)),
3109 (Some("5"), None), ]);
3111
3112 let input_values = Arc::new(input_values) as ArrayRef;
3113
3114 test_case_when_literal_lookup(
3116 Arc::clone(&input_values),
3117 &lookup_map,
3118 None,
3119 Arc::new(expected.clone()),
3120 );
3121
3122 let else_value = 101;
3124
3125 let expected_with_else = expected
3127 .iter()
3128 .map(|item| item.unwrap_or(else_value))
3129 .map(Some)
3130 .collect::<Int32Array>();
3131
3132 test_case_when_literal_lookup(
3134 input_values,
3135 &lookup_map,
3136 Some(ScalarValue::Int32(Some(else_value))),
3137 Arc::new(expected_with_else),
3138 );
3139 }
3140
3141 #[test]
3146 fn nested_self_referential_case_hash_stays_bounded() -> Result<()> {
3147 use std::hash::Hasher;
3148
3149 #[derive(Default)]
3150 struct CountingHasher {
3151 write_calls: usize,
3152 bytes_written: usize,
3153 }
3154
3155 impl Hasher for CountingHasher {
3156 fn finish(&self) -> u64 {
3157 0
3158 }
3159
3160 fn write(&mut self, bytes: &[u8]) {
3161 self.write_calls += 1;
3162 self.bytes_written += bytes.len();
3163 }
3164 }
3165
3166 let schema =
3167 Arc::new(Schema::new(vec![Field::new("kind", DataType::Utf8, true)]));
3168
3169 let kind = col("kind", &schema)?;
3170 let mut label = Arc::clone(&kind);
3171
3172 let num_levels = 18;
3173 for idx in 0..num_levels {
3174 let predicate = Arc::new(BinaryExpr::new(
3175 Arc::clone(&kind),
3176 Operator::Eq,
3177 lit(idx.to_string()),
3178 )) as Arc<dyn PhysicalExpr>;
3179
3180 label = case(None, vec![(predicate, lit("label"))], Some(label))?;
3181 }
3182
3183 let mut hasher = CountingHasher::default();
3184 label.hash(&mut hasher);
3185
3186 assert!(
3187 hasher.write_calls < 50_000,
3188 "hashing nested CASE expression took {} hasher writes and {} bytes",
3189 hasher.write_calls,
3190 hasher.bytes_written
3191 );
3192
3193 Ok(())
3194 }
3195}