1use crate::expressions::try_cast;
19use crate::PhysicalExpr;
20use std::borrow::Cow;
21use std::hash::Hash;
22use std::{any::Any, sync::Arc};
23
24use arrow::array::*;
25use arrow::compute::kernels::zip::zip;
26use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
27use arrow::datatypes::{DataType, Schema};
28use datafusion_common::cast::as_boolean_array;
29use datafusion_common::{
30 exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
31};
32use datafusion_expr::ColumnarValue;
33
34use super::{Column, Literal};
35use datafusion_physical_expr_common::datum::compare_with_eq;
36use itertools::Itertools;
37
/// One `WHEN ... THEN ...` branch of a CASE expression: the first element is
/// the WHEN expression, the second is the THEN expression.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
39
/// Evaluation strategy chosen once by `CaseExpr::try_new` and dispatched on
/// by `CaseExpr::evaluate`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General path for `CASE WHEN <cond> THEN ...` (no base expression);
    /// used when none of the specializations below apply.
    NoExpression,
    /// `CASE <expr> WHEN <value> THEN ...`; selected whenever a base
    /// expression is present.
    WithExpression,
    /// Exactly one WHEN branch, no ELSE, and the THEN expression satisfies
    /// `is_cheap_and_infallible` (currently: it is a `Column`).
    InfallibleExprOrNull,
    /// Exactly one WHEN branch whose THEN and ELSE are both `Literal`s.
    ScalarOrScalar,
    /// Exactly one WHEN branch whose THEN and ELSE both satisfy
    /// `is_cheap_and_infallible`.
    ExpressionOrExpression,
}
70
/// Physical CASE expression.
///
/// Covers both `CASE <expr> WHEN <value> THEN <result> ... END` (when `expr`
/// is set) and `CASE WHEN <condition> THEN <result> ... END` (when it is not),
/// each with an optional `ELSE <result>`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// Optional base expression that each WHEN value is compared against.
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// The `(WHEN, THEN)` branches, in evaluation order. `try_new` rejects an
    /// empty list.
    when_then_expr: Vec<WhenThen>,
    /// Optional ELSE expression. `try_new` stores `None` when the supplied
    /// ELSE is a NULL literal.
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    /// Evaluation strategy picked by `try_new`.
    eval_method: EvalMethod,
}
99
100impl std::fmt::Display for CaseExpr {
101 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102 write!(f, "CASE ")?;
103 if let Some(e) = &self.expr {
104 write!(f, "{e} ")?;
105 }
106 for (w, t) in &self.when_then_expr {
107 write!(f, "WHEN {w} THEN {t} ")?;
108 }
109 if let Some(e) = &self.else_expr {
110 write!(f, "ELSE {e} ")?;
111 }
112 write!(f, "END")
113 }
114}
115
116fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
122 expr.as_any().is::<Column>()
123}
124
125impl CaseExpr {
126 pub fn try_new(
128 expr: Option<Arc<dyn PhysicalExpr>>,
129 when_then_expr: Vec<WhenThen>,
130 else_expr: Option<Arc<dyn PhysicalExpr>>,
131 ) -> Result<Self> {
132 let else_expr = match &else_expr {
135 Some(e) => match e.as_any().downcast_ref::<Literal>() {
136 Some(lit) if lit.value().is_null() => None,
137 _ => else_expr,
138 },
139 _ => else_expr,
140 };
141
142 if when_then_expr.is_empty() {
143 exec_err!("There must be at least one WHEN clause")
144 } else {
145 let eval_method = if expr.is_some() {
146 EvalMethod::WithExpression
147 } else if when_then_expr.len() == 1
148 && is_cheap_and_infallible(&(when_then_expr[0].1))
149 && else_expr.is_none()
150 {
151 EvalMethod::InfallibleExprOrNull
152 } else if when_then_expr.len() == 1
153 && when_then_expr[0].1.as_any().is::<Literal>()
154 && else_expr.is_some()
155 && else_expr.as_ref().unwrap().as_any().is::<Literal>()
156 {
157 EvalMethod::ScalarOrScalar
158 } else if when_then_expr.len() == 1
159 && is_cheap_and_infallible(&(when_then_expr[0].1))
160 && else_expr.as_ref().is_some_and(is_cheap_and_infallible)
161 {
162 EvalMethod::ExpressionOrExpression
163 } else {
164 EvalMethod::NoExpression
165 };
166
167 Ok(Self {
168 expr,
169 when_then_expr,
170 else_expr,
171 eval_method,
172 })
173 }
174 }
175
176 pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
178 self.expr.as_ref()
179 }
180
181 pub fn when_then_expr(&self) -> &[WhenThen] {
183 &self.when_then_expr
184 }
185
186 pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
188 self.else_expr.as_ref()
189 }
190}
191
192impl CaseExpr {
193 fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
201 let return_type = self.data_type(&batch.schema())?;
202 let expr = self.expr.as_ref().unwrap();
203 let base_value = expr.evaluate(batch)?;
204 let base_value = base_value.into_array(batch.num_rows())?;
205 let base_nulls = is_null(base_value.as_ref())?;
206
207 let mut current_value = new_null_array(&return_type, batch.num_rows());
209 let mut remainder = not(&base_nulls)?;
211 for i in 0..self.when_then_expr.len() {
212 let when_value = self.when_then_expr[i]
213 .0
214 .evaluate_selection(batch, &remainder)?;
215 let when_value = when_value.into_array(batch.num_rows())?;
216 let when_match = compare_with_eq(
218 &when_value,
219 &base_value,
220 base_value.data_type().is_nested(),
223 )?;
224 let when_match = match when_match.null_count() {
226 0 => Cow::Borrowed(&when_match),
227 _ => Cow::Owned(prep_null_mask_filter(&when_match)),
228 };
229 let when_match = and(&when_match, &remainder)?;
231
232 if when_match.true_count() == 0 {
234 continue;
235 }
236
237 let then_value = self.when_then_expr[i]
238 .1
239 .evaluate_selection(batch, &when_match)?;
240
241 current_value = match then_value {
242 ColumnarValue::Scalar(ScalarValue::Null) => {
243 nullif(current_value.as_ref(), &when_match)?
244 }
245 ColumnarValue::Scalar(then_value) => {
246 zip(&when_match, &then_value.to_scalar()?, ¤t_value)?
247 }
248 ColumnarValue::Array(then_value) => {
249 zip(&when_match, &then_value, ¤t_value)?
250 }
251 };
252
253 remainder = and_not(&remainder, &when_match)?;
254 }
255
256 if let Some(e) = self.else_expr() {
257 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
259 remainder = or(&base_nulls, &remainder)?;
261 let else_ = expr
262 .evaluate_selection(batch, &remainder)?
263 .into_array(batch.num_rows())?;
264 current_value = zip(&remainder, &else_, ¤t_value)?;
265 }
266
267 Ok(ColumnarValue::Array(current_value))
268 }
269
270 fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
278 let return_type = self.data_type(&batch.schema())?;
279
280 let mut current_value = new_null_array(&return_type, batch.num_rows());
282 let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
283 for i in 0..self.when_then_expr.len() {
284 let when_value = self.when_then_expr[i]
285 .0
286 .evaluate_selection(batch, &remainder)?;
287 let when_value = when_value.into_array(batch.num_rows())?;
288 let when_value = as_boolean_array(&when_value).map_err(|_| {
289 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
290 })?;
291 let when_value = match when_value.null_count() {
293 0 => Cow::Borrowed(when_value),
294 _ => Cow::Owned(prep_null_mask_filter(when_value)),
295 };
296 let when_value = and(&when_value, &remainder)?;
298
299 if when_value.true_count() == 0 {
301 continue;
302 }
303
304 let then_value = self.when_then_expr[i]
305 .1
306 .evaluate_selection(batch, &when_value)?;
307
308 current_value = match then_value {
309 ColumnarValue::Scalar(ScalarValue::Null) => {
310 nullif(current_value.as_ref(), &when_value)?
311 }
312 ColumnarValue::Scalar(then_value) => {
313 zip(&when_value, &then_value.to_scalar()?, ¤t_value)?
314 }
315 ColumnarValue::Array(then_value) => {
316 zip(&when_value, &then_value, ¤t_value)?
317 }
318 };
319
320 remainder = and_not(&remainder, &when_value)?;
323 }
324
325 if let Some(e) = self.else_expr() {
326 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
328 let else_ = expr
329 .evaluate_selection(batch, &remainder)?
330 .into_array(batch.num_rows())?;
331 current_value = zip(&remainder, &else_, ¤t_value)?;
332 }
333
334 Ok(ColumnarValue::Array(current_value))
335 }
336
337 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
347 let when_expr = &self.when_then_expr[0].0;
348 let then_expr = &self.when_then_expr[0].1;
349
350 match when_expr.evaluate(batch)? {
351 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
353 then_expr.evaluate(batch)
354 }
355 ColumnarValue::Scalar(_) => {
357 ScalarValue::try_from(self.data_type(&batch.schema())?)
359 .map(ColumnarValue::Scalar)
360 }
361 ColumnarValue::Array(bit_mask) => {
363 let bit_mask = bit_mask
364 .as_any()
365 .downcast_ref::<BooleanArray>()
366 .expect("predicate should evaluate to a boolean array");
367 let bit_mask = match bit_mask.null_count() {
369 0 => not(bit_mask)?,
370 _ => not(&prep_null_mask_filter(bit_mask))?,
371 };
372 match then_expr.evaluate(batch)? {
373 ColumnarValue::Array(array) => {
374 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
375 }
376 ColumnarValue::Scalar(_) => {
377 internal_err!("expression did not evaluate to an array")
378 }
379 }
380 }
381 }
382 }
383
384 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
385 let return_type = self.data_type(&batch.schema())?;
386
387 let when_value = self.when_then_expr[0].0.evaluate(batch)?;
389 let when_value = when_value.into_array(batch.num_rows())?;
390 let when_value = as_boolean_array(&when_value).map_err(|_| {
391 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
392 })?;
393
394 let when_value = match when_value.null_count() {
396 0 => Cow::Borrowed(when_value),
397 _ => Cow::Owned(prep_null_mask_filter(when_value)),
398 };
399
400 let then_value = self.when_then_expr[0].1.evaluate(batch)?;
402 let then_value = Scalar::new(then_value.into_array(1)?);
403
404 let Some(e) = self.else_expr() else {
405 return internal_err!("expression did not evaluate to an array");
406 };
407 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
409 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
410 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
411 }
412
413 fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
414 let return_type = self.data_type(&batch.schema())?;
415
416 let when_value = self.when_then_expr[0].0.evaluate(batch)?;
418 let when_value = when_value.into_array(batch.num_rows())?;
419 let when_value = as_boolean_array(&when_value).map_err(|e| {
420 DataFusionError::Context(
421 "WHEN expression did not return a BooleanArray".to_string(),
422 Box::new(e),
423 )
424 })?;
425
426 let when_value = match when_value.null_count() {
428 0 => Cow::Borrowed(when_value),
429 _ => Cow::Owned(prep_null_mask_filter(when_value)),
430 };
431
432 let then_value = self.when_then_expr[0]
433 .1
434 .evaluate_selection(batch, &when_value)?
435 .into_array(batch.num_rows())?;
436
437 let remainder = not(&when_value)?;
439 let e = self.else_expr.as_ref().unwrap();
440 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
442 .unwrap_or_else(|_| Arc::clone(e));
443 let else_ = expr
444 .evaluate_selection(batch, &remainder)?
445 .into_array(batch.num_rows())?;
446
447 Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
448 }
449}
450
451impl PhysicalExpr for CaseExpr {
    /// Returns `self` as [`Any`] so callers can downcast to `CaseExpr`.
    fn as_any(&self) -> &dyn Any {
        self
    }
456
457 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
458 let mut data_type = DataType::Null;
461 for i in 0..self.when_then_expr.len() {
462 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
463 if !data_type.equals_datatype(&DataType::Null) {
464 break;
465 }
466 }
467 if data_type.equals_datatype(&DataType::Null) {
469 if let Some(e) = &self.else_expr {
470 data_type = e.data_type(input_schema)?;
471 }
472 }
473
474 Ok(data_type)
475 }
476
477 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
478 let then_nullable = self
480 .when_then_expr
481 .iter()
482 .map(|(_, t)| t.nullable(input_schema))
483 .collect::<Result<Vec<_>>>()?;
484 if then_nullable.contains(&true) {
485 Ok(true)
486 } else if let Some(e) = &self.else_expr {
487 e.nullable(input_schema)
488 } else {
489 Ok(true)
492 }
493 }
494
495 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
496 match self.eval_method {
497 EvalMethod::WithExpression => {
498 self.case_when_with_expr(batch)
501 }
502 EvalMethod::NoExpression => {
503 self.case_when_no_expr(batch)
506 }
507 EvalMethod::InfallibleExprOrNull => {
508 self.case_column_or_null(batch)
510 }
511 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
512 EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
513 }
514 }
515
516 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
517 let mut children = vec![];
518 if let Some(expr) = &self.expr {
519 children.push(expr)
520 }
521 self.when_then_expr.iter().for_each(|(cond, value)| {
522 children.push(cond);
523 children.push(value);
524 });
525
526 if let Some(else_expr) = &self.else_expr {
527 children.push(else_expr)
528 }
529 children
530 }
531
532 fn with_new_children(
534 self: Arc<Self>,
535 children: Vec<Arc<dyn PhysicalExpr>>,
536 ) -> Result<Arc<dyn PhysicalExpr>> {
537 if children.len() != self.children().len() {
538 internal_err!("CaseExpr: Wrong number of children")
539 } else {
540 let (expr, when_then_expr, else_expr) =
541 match (self.expr().is_some(), self.else_expr().is_some()) {
542 (true, true) => (
543 Some(&children[0]),
544 &children[1..children.len() - 1],
545 Some(&children[children.len() - 1]),
546 ),
547 (true, false) => {
548 (Some(&children[0]), &children[1..children.len()], None)
549 }
550 (false, true) => (
551 None,
552 &children[0..children.len() - 1],
553 Some(&children[children.len() - 1]),
554 ),
555 (false, false) => (None, &children[0..children.len()], None),
556 };
557 Ok(Arc::new(CaseExpr::try_new(
558 expr.cloned(),
559 when_then_expr.iter().cloned().tuples().collect(),
560 else_expr.cloned(),
561 )?))
562 }
563 }
564
565 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566 write!(f, "CASE ")?;
567 if let Some(e) = &self.expr {
568 e.fmt_sql(f)?;
569 write!(f, " ")?;
570 }
571
572 for (w, t) in &self.when_then_expr {
573 write!(f, "WHEN ")?;
574 w.fmt_sql(f)?;
575 write!(f, " THEN ")?;
576 t.fmt_sql(f)?;
577 write!(f, " ")?;
578 }
579
580 if let Some(e) = &self.else_expr {
581 write!(f, "ELSE ")?;
582 e.fmt_sql(f)?;
583 write!(f, " ")?;
584 }
585 write!(f, "END")
586 }
587}
588
589pub fn case(
591 expr: Option<Arc<dyn PhysicalExpr>>,
592 when_thens: Vec<WhenThen>,
593 else_expr: Option<Arc<dyn PhysicalExpr>>,
594) -> Result<Arc<dyn PhysicalExpr>> {
595 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
596}
597
598#[cfg(test)]
599mod tests {
600 use super::*;
601
602 use crate::expressions::{binary, cast, col, lit, BinaryExpr};
603 use arrow::buffer::Buffer;
604 use arrow::datatypes::DataType::Float64;
605 use arrow::datatypes::Field;
606 use datafusion_common::cast::{as_float64_array, as_int32_array};
607 use datafusion_common::plan_err;
608 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
609 use datafusion_expr::type_coercion::binary::comparison_coercion;
610 use datafusion_expr::Operator;
611 use datafusion_physical_expr_common::physical_expr::fmt_sql;
612
613 #[test]
614 fn case_with_expr() -> Result<()> {
615 let batch = case_test_batch()?;
616 let schema = batch.schema();
617
618 let when1 = lit("foo");
620 let then1 = lit(123i32);
621 let when2 = lit("bar");
622 let then2 = lit(456i32);
623
624 let expr = generate_case_when_with_type_coercion(
625 Some(col("a", &schema)?),
626 vec![(when1, then1), (when2, then2)],
627 None,
628 schema.as_ref(),
629 )?;
630 let result = expr
631 .evaluate(&batch)?
632 .into_array(batch.num_rows())
633 .expect("Failed to convert to array");
634 let result = as_int32_array(&result)?;
635
636 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
637
638 assert_eq!(expected, result);
639
640 Ok(())
641 }
642
643 #[test]
644 fn case_with_expr_else() -> Result<()> {
645 let batch = case_test_batch()?;
646 let schema = batch.schema();
647
648 let when1 = lit("foo");
650 let then1 = lit(123i32);
651 let when2 = lit("bar");
652 let then2 = lit(456i32);
653 let else_value = lit(999i32);
654
655 let expr = generate_case_when_with_type_coercion(
656 Some(col("a", &schema)?),
657 vec![(when1, then1), (when2, then2)],
658 Some(else_value),
659 schema.as_ref(),
660 )?;
661 let result = expr
662 .evaluate(&batch)?
663 .into_array(batch.num_rows())
664 .expect("Failed to convert to array");
665 let result = as_int32_array(&result)?;
666
667 let expected =
668 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
669
670 assert_eq!(expected, result);
671
672 Ok(())
673 }
674
675 #[test]
676 fn case_with_expr_divide_by_zero() -> Result<()> {
677 let batch = case_test_batch1()?;
678 let schema = batch.schema();
679
680 let when1 = lit(0i32);
682 let then1 = lit(ScalarValue::Float64(None));
683 let else_value = binary(
684 lit(25.0f64),
685 Operator::Divide,
686 cast(col("a", &schema)?, &batch.schema(), Float64)?,
687 &batch.schema(),
688 )?;
689
690 let expr = generate_case_when_with_type_coercion(
691 Some(col("a", &schema)?),
692 vec![(when1, then1)],
693 Some(else_value),
694 schema.as_ref(),
695 )?;
696 let result = expr
697 .evaluate(&batch)?
698 .into_array(batch.num_rows())
699 .expect("Failed to convert to array");
700 let result =
701 as_float64_array(&result).expect("failed to downcast to Float64Array");
702
703 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
704
705 assert_eq!(expected, result);
706
707 Ok(())
708 }
709
710 #[test]
711 fn case_without_expr() -> Result<()> {
712 let batch = case_test_batch()?;
713 let schema = batch.schema();
714
715 let when1 = binary(
717 col("a", &schema)?,
718 Operator::Eq,
719 lit("foo"),
720 &batch.schema(),
721 )?;
722 let then1 = lit(123i32);
723 let when2 = binary(
724 col("a", &schema)?,
725 Operator::Eq,
726 lit("bar"),
727 &batch.schema(),
728 )?;
729 let then2 = lit(456i32);
730
731 let expr = generate_case_when_with_type_coercion(
732 None,
733 vec![(when1, then1), (when2, then2)],
734 None,
735 schema.as_ref(),
736 )?;
737 let result = expr
738 .evaluate(&batch)?
739 .into_array(batch.num_rows())
740 .expect("Failed to convert to array");
741 let result = as_int32_array(&result)?;
742
743 let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
744
745 assert_eq!(expected, result);
746
747 Ok(())
748 }
749
750 #[test]
751 fn case_with_expr_when_null() -> Result<()> {
752 let batch = case_test_batch()?;
753 let schema = batch.schema();
754
755 let when1 = lit(ScalarValue::Utf8(None));
757 let then1 = lit(0i32);
758 let when2 = col("a", &schema)?;
759 let then2 = lit(123i32);
760 let else_value = lit(999i32);
761
762 let expr = generate_case_when_with_type_coercion(
763 Some(col("a", &schema)?),
764 vec![(when1, then1), (when2, then2)],
765 Some(else_value),
766 schema.as_ref(),
767 )?;
768 let result = expr
769 .evaluate(&batch)?
770 .into_array(batch.num_rows())
771 .expect("Failed to convert to array");
772 let result = as_int32_array(&result)?;
773
774 let expected =
775 &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
776
777 assert_eq!(expected, result);
778
779 Ok(())
780 }
781
782 #[test]
783 fn case_without_expr_divide_by_zero() -> Result<()> {
784 let batch = case_test_batch1()?;
785 let schema = batch.schema();
786
787 let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
789 let then1 = binary(
790 lit(25.0f64),
791 Operator::Divide,
792 cast(col("a", &schema)?, &batch.schema(), Float64)?,
793 &batch.schema(),
794 )?;
795 let x = lit(ScalarValue::Float64(None));
796
797 let expr = generate_case_when_with_type_coercion(
798 None,
799 vec![(when1, then1)],
800 Some(x),
801 schema.as_ref(),
802 )?;
803 let result = expr
804 .evaluate(&batch)?
805 .into_array(batch.num_rows())
806 .expect("Failed to convert to array");
807 let result =
808 as_float64_array(&result).expect("failed to downcast to Float64Array");
809
810 let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
811
812 assert_eq!(expected, result);
813
814 Ok(())
815 }
816
817 fn case_test_batch1() -> Result<RecordBatch> {
818 let schema = Schema::new(vec![
819 Field::new("a", DataType::Int32, true),
820 Field::new("b", DataType::Int32, true),
821 Field::new("c", DataType::Int32, true),
822 ]);
823 let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
824 let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
825 let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
826 let batch = RecordBatch::try_new(
827 Arc::new(schema),
828 vec![Arc::new(a), Arc::new(b), Arc::new(c)],
829 )?;
830 Ok(batch)
831 }
832
833 #[test]
834 fn case_without_expr_else() -> Result<()> {
835 let batch = case_test_batch()?;
836 let schema = batch.schema();
837
838 let when1 = binary(
840 col("a", &schema)?,
841 Operator::Eq,
842 lit("foo"),
843 &batch.schema(),
844 )?;
845 let then1 = lit(123i32);
846 let when2 = binary(
847 col("a", &schema)?,
848 Operator::Eq,
849 lit("bar"),
850 &batch.schema(),
851 )?;
852 let then2 = lit(456i32);
853 let else_value = lit(999i32);
854
855 let expr = generate_case_when_with_type_coercion(
856 None,
857 vec![(when1, then1), (when2, then2)],
858 Some(else_value),
859 schema.as_ref(),
860 )?;
861 let result = expr
862 .evaluate(&batch)?
863 .into_array(batch.num_rows())
864 .expect("Failed to convert to array");
865 let result = as_int32_array(&result)?;
866
867 let expected =
868 &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
869
870 assert_eq!(expected, result);
871
872 Ok(())
873 }
874
875 #[test]
876 fn case_with_type_cast() -> Result<()> {
877 let batch = case_test_batch()?;
878 let schema = batch.schema();
879
880 let when = binary(
882 col("a", &schema)?,
883 Operator::Eq,
884 lit("foo"),
885 &batch.schema(),
886 )?;
887 let then = lit(123.3f64);
888 let else_value = lit(999i32);
889
890 let expr = generate_case_when_with_type_coercion(
891 None,
892 vec![(when, then)],
893 Some(else_value),
894 schema.as_ref(),
895 )?;
896 let result = expr
897 .evaluate(&batch)?
898 .into_array(batch.num_rows())
899 .expect("Failed to convert to array");
900 let result =
901 as_float64_array(&result).expect("failed to downcast to Float64Array");
902
903 let expected =
904 &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
905
906 assert_eq!(expected, result);
907
908 Ok(())
909 }
910
911 #[test]
912 fn case_with_matches_and_nulls() -> Result<()> {
913 let batch = case_test_batch_nulls()?;
914 let schema = batch.schema();
915
916 let when = binary(
918 col("load4", &schema)?,
919 Operator::Eq,
920 lit(1.77f64),
921 &batch.schema(),
922 )?;
923 let then = col("load4", &schema)?;
924
925 let expr = generate_case_when_with_type_coercion(
926 None,
927 vec![(when, then)],
928 None,
929 schema.as_ref(),
930 )?;
931 let result = expr
932 .evaluate(&batch)?
933 .into_array(batch.num_rows())
934 .expect("Failed to convert to array");
935 let result =
936 as_float64_array(&result).expect("failed to downcast to Float64Array");
937
938 let expected =
939 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
940
941 assert_eq!(expected, result);
942
943 Ok(())
944 }
945
946 #[test]
947 fn case_with_scalar_predicate() -> Result<()> {
948 let batch = case_test_batch_nulls()?;
949 let schema = batch.schema();
950
951 let when = lit(true);
953 let then = col("load4", &schema)?;
954 let expr = generate_case_when_with_type_coercion(
955 None,
956 vec![(when, then)],
957 None,
958 schema.as_ref(),
959 )?;
960
961 let result = expr
963 .evaluate(&batch)?
964 .into_array(batch.num_rows())
965 .expect("Failed to convert to array");
966 let result =
967 as_float64_array(&result).expect("failed to downcast to Float64Array");
968 let expected = &Float64Array::from(vec![
969 Some(1.77),
970 None,
971 None,
972 Some(1.78),
973 None,
974 Some(1.77),
975 ]);
976 assert_eq!(expected, result);
977
978 let expected = Float64Array::from(vec![Some(1.1)]);
980 let batch =
981 RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
982 let result = expr
983 .evaluate(&batch)?
984 .into_array(batch.num_rows())
985 .expect("Failed to convert to array");
986 let result =
987 as_float64_array(&result).expect("failed to downcast to Float64Array");
988 assert_eq!(&expected, result);
989
990 Ok(())
991 }
992
993 #[test]
994 fn case_expr_matches_and_nulls() -> Result<()> {
995 let batch = case_test_batch_nulls()?;
996 let schema = batch.schema();
997
998 let expr = col("load4", &schema)?;
1000 let when = lit(1.77f64);
1001 let then = col("load4", &schema)?;
1002
1003 let expr = generate_case_when_with_type_coercion(
1004 Some(expr),
1005 vec![(when, then)],
1006 None,
1007 schema.as_ref(),
1008 )?;
1009 let result = expr
1010 .evaluate(&batch)?
1011 .into_array(batch.num_rows())
1012 .expect("Failed to convert to array");
1013 let result =
1014 as_float64_array(&result).expect("failed to downcast to Float64Array");
1015
1016 let expected =
1017 &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1018
1019 assert_eq!(expected, result);
1020
1021 Ok(())
1022 }
1023
1024 #[test]
1025 fn test_when_null_and_some_cond_else_null() -> Result<()> {
1026 let batch = case_test_batch()?;
1027 let schema = batch.schema();
1028
1029 let when = binary(
1030 Arc::new(Literal::new(ScalarValue::Boolean(None))),
1031 Operator::And,
1032 binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
1033 &schema,
1034 )?;
1035 let then = col("a", &schema)?;
1036
1037 let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
1039 let result = expr
1040 .evaluate(&batch)?
1041 .into_array(batch.num_rows())
1042 .expect("Failed to convert to array");
1043 let result = as_string_array(&result);
1044
1045 assert_eq!(result.logical_null_count(), batch.num_rows());
1047 Ok(())
1048 }
1049
1050 fn case_test_batch() -> Result<RecordBatch> {
1051 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1052 let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1053 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1054 Ok(batch)
1055 }
1056
1057 fn case_test_batch_nulls() -> Result<RecordBatch> {
1060 let load4: Float64Array = vec![
1061 Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
1068 .into_iter()
1069 .collect();
1070
1071 let null_buffer = Buffer::from([0b00101001u8]);
1073 let load4 = load4
1074 .into_data()
1075 .into_builder()
1076 .null_bit_buffer(Some(null_buffer))
1077 .build()
1078 .unwrap();
1079 let load4: Float64Array = load4.into();
1080
1081 let batch =
1082 RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1083 Ok(batch)
1084 }
1085
1086 #[test]
1087 fn case_test_incompatible() -> Result<()> {
1088 let batch = case_test_batch()?;
1091 let schema = batch.schema();
1092
1093 let when1 = binary(
1095 col("a", &schema)?,
1096 Operator::Eq,
1097 lit("foo"),
1098 &batch.schema(),
1099 )?;
1100 let then1 = lit(123i32);
1101 let when2 = binary(
1102 col("a", &schema)?,
1103 Operator::Eq,
1104 lit("bar"),
1105 &batch.schema(),
1106 )?;
1107 let then2 = lit(true);
1108
1109 let expr = generate_case_when_with_type_coercion(
1110 None,
1111 vec![(when1, then1), (when2, then2)],
1112 None,
1113 schema.as_ref(),
1114 );
1115 assert!(expr.is_err());
1116
1117 let when1 = binary(
1122 col("a", &schema)?,
1123 Operator::Eq,
1124 lit("foo"),
1125 &batch.schema(),
1126 )?;
1127 let then1 = lit(123i32);
1128 let when2 = binary(
1129 col("a", &schema)?,
1130 Operator::Eq,
1131 lit("bar"),
1132 &batch.schema(),
1133 )?;
1134 let then2 = lit(456i64);
1135 let else_expr = lit(1.23f64);
1136
1137 let expr = generate_case_when_with_type_coercion(
1138 None,
1139 vec![(when1, then1), (when2, then2)],
1140 Some(else_expr),
1141 schema.as_ref(),
1142 );
1143 assert!(expr.is_ok());
1144 let result_type = expr.unwrap().data_type(schema.as_ref())?;
1145 assert_eq!(Float64, result_type);
1146 Ok(())
1147 }
1148
1149 #[test]
1150 fn case_eq() -> Result<()> {
1151 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1152
1153 let when1 = lit("foo");
1154 let then1 = lit(123i32);
1155 let when2 = lit("bar");
1156 let then2 = lit(456i32);
1157 let else_value = lit(999i32);
1158
1159 let expr1 = generate_case_when_with_type_coercion(
1160 Some(col("a", &schema)?),
1161 vec![
1162 (Arc::clone(&when1), Arc::clone(&then1)),
1163 (Arc::clone(&when2), Arc::clone(&then2)),
1164 ],
1165 Some(Arc::clone(&else_value)),
1166 &schema,
1167 )?;
1168
1169 let expr2 = generate_case_when_with_type_coercion(
1170 Some(col("a", &schema)?),
1171 vec![
1172 (Arc::clone(&when1), Arc::clone(&then1)),
1173 (Arc::clone(&when2), Arc::clone(&then2)),
1174 ],
1175 Some(Arc::clone(&else_value)),
1176 &schema,
1177 )?;
1178
1179 let expr3 = generate_case_when_with_type_coercion(
1180 Some(col("a", &schema)?),
1181 vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
1182 None,
1183 &schema,
1184 )?;
1185
1186 let expr4 = generate_case_when_with_type_coercion(
1187 Some(col("a", &schema)?),
1188 vec![(when1, then1)],
1189 Some(else_value),
1190 &schema,
1191 )?;
1192
1193 assert!(expr1.eq(&expr2));
1194 assert!(expr2.eq(&expr1));
1195
1196 assert!(expr2.ne(&expr3));
1197 assert!(expr3.ne(&expr2));
1198
1199 assert!(expr1.ne(&expr4));
1200 assert!(expr4.ne(&expr1));
1201
1202 Ok(())
1203 }
1204
1205 #[test]
1206 fn case_transform() -> Result<()> {
1207 let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1208
1209 let when1 = lit("foo");
1210 let then1 = lit(123i32);
1211 let when2 = lit("bar");
1212 let then2 = lit(456i32);
1213 let else_value = lit(999i32);
1214
1215 let expr = generate_case_when_with_type_coercion(
1216 Some(col("a", &schema)?),
1217 vec![
1218 (Arc::clone(&when1), Arc::clone(&then1)),
1219 (Arc::clone(&when2), Arc::clone(&then2)),
1220 ],
1221 Some(Arc::clone(&else_value)),
1222 &schema,
1223 )?;
1224
1225 let expr2 = Arc::clone(&expr)
1226 .transform(|e| {
1227 let transformed = match e.as_any().downcast_ref::<Literal>() {
1228 Some(lit_value) => match lit_value.value() {
1229 ScalarValue::Utf8(Some(str_value)) => {
1230 Some(lit(str_value.to_uppercase()))
1231 }
1232 _ => None,
1233 },
1234 _ => None,
1235 };
1236 Ok(if let Some(transformed) = transformed {
1237 Transformed::yes(transformed)
1238 } else {
1239 Transformed::no(e)
1240 })
1241 })
1242 .data()
1243 .unwrap();
1244
1245 let expr3 = Arc::clone(&expr)
1246 .transform_down(|e| {
1247 let transformed = match e.as_any().downcast_ref::<Literal>() {
1248 Some(lit_value) => match lit_value.value() {
1249 ScalarValue::Utf8(Some(str_value)) => {
1250 Some(lit(str_value.to_uppercase()))
1251 }
1252 _ => None,
1253 },
1254 _ => None,
1255 };
1256 Ok(if let Some(transformed) = transformed {
1257 Transformed::yes(transformed)
1258 } else {
1259 Transformed::no(e)
1260 })
1261 })
1262 .data()
1263 .unwrap();
1264
1265 assert!(expr.ne(&expr2));
1266 assert!(expr2.eq(&expr3));
1267
1268 Ok(())
1269 }
1270
1271 #[test]
1272 fn test_column_or_null_specialization() -> Result<()> {
1273 let mut c1 = Int32Builder::new();
1275 let mut c2 = StringBuilder::new();
1276 for i in 0..1000 {
1277 c1.append_value(i);
1278 if i % 7 == 0 {
1279 c2.append_null();
1280 } else {
1281 c2.append_value(format!("string {i}"));
1282 }
1283 }
1284 let c1 = Arc::new(c1.finish());
1285 let c2 = Arc::new(c2.finish());
1286 let schema = Schema::new(vec![
1287 Field::new("c1", DataType::Int32, true),
1288 Field::new("c2", DataType::Utf8, true),
1289 ]);
1290 let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
1291
1292 let predicate = Arc::new(BinaryExpr::new(
1294 make_col("c1", 0),
1295 Operator::LtEq,
1296 make_lit_i32(250),
1297 ));
1298 let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
1299 assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
1300 match expr.evaluate(&batch)? {
1301 ColumnarValue::Array(array) => {
1302 assert_eq!(1000, array.len());
1303 assert_eq!(785, array.null_count());
1304 }
1305 _ => unreachable!(),
1306 }
1307 Ok(())
1308 }
1309
1310 #[test]
1311 fn test_expr_or_expr_specialization() -> Result<()> {
1312 let batch = case_test_batch1()?;
1313 let schema = batch.schema();
1314 let when = binary(
1315 col("a", &schema)?,
1316 Operator::LtEq,
1317 lit(2i32),
1318 &batch.schema(),
1319 )?;
1320 let then = col("b", &schema)?;
1321 let else_expr = col("c", &schema)?;
1322 let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
1323 assert!(matches!(
1324 expr.eval_method,
1325 EvalMethod::ExpressionOrExpression
1326 ));
1327 let result = expr
1328 .evaluate(&batch)?
1329 .into_array(batch.num_rows())
1330 .expect("Failed to convert to array");
1331 let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
1332
1333 let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
1334
1335 assert_eq!(expected, result);
1336 Ok(())
1337 }
1338
1339 fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1340 Arc::new(Column::new(name, index))
1341 }
1342
1343 fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1344 Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1345 }
1346
1347 fn generate_case_when_with_type_coercion(
1348 expr: Option<Arc<dyn PhysicalExpr>>,
1349 when_thens: Vec<WhenThen>,
1350 else_expr: Option<Arc<dyn PhysicalExpr>>,
1351 input_schema: &Schema,
1352 ) -> Result<Arc<dyn PhysicalExpr>> {
1353 let coerce_type =
1354 get_case_common_type(&when_thens, else_expr.clone(), input_schema);
1355 let (when_thens, else_expr) = match coerce_type {
1356 None => plan_err!(
1357 "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
1358 ),
1359 Some(data_type) => {
1360 let left = when_thens
1362 .into_iter()
1363 .map(|(when, then)| {
1364 let then = try_cast(then, input_schema, data_type.clone())?;
1365 Ok((when, then))
1366 })
1367 .collect::<Result<Vec<_>>>()?;
1368 let right = match else_expr {
1369 None => None,
1370 Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
1371 };
1372
1373 Ok((left, right))
1374 }
1375 }?;
1376 case(expr, when_thens, else_expr)
1377 }
1378
1379 fn get_case_common_type(
1380 when_thens: &[WhenThen],
1381 else_expr: Option<Arc<dyn PhysicalExpr>>,
1382 input_schema: &Schema,
1383 ) -> Option<DataType> {
1384 let thens_type = when_thens
1385 .iter()
1386 .map(|when_then| {
1387 let data_type = &when_then.1.data_type(input_schema).unwrap();
1388 data_type.clone()
1389 })
1390 .collect::<Vec<_>>();
1391 let else_type = match else_expr {
1392 None => {
1393 thens_type[0].clone()
1395 }
1396 Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
1397 };
1398 thens_type
1399 .iter()
1400 .try_fold(else_type, |left_type, right_type| {
1401 comparison_coercion(&left_type, right_type)
1404 })
1405 }
1406
1407 #[test]
1408 fn test_fmt_sql() -> Result<()> {
1409 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1410
1411 let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
1413 let then = lit(123.3f64);
1414 let else_value = lit(999i32);
1415
1416 let expr = generate_case_when_with_type_coercion(
1417 None,
1418 vec![(when, then)],
1419 Some(else_value),
1420 &schema,
1421 )?;
1422
1423 let display_string = expr.to_string();
1424 assert_eq!(
1425 display_string,
1426 "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
1427 );
1428
1429 let sql_string = fmt_sql(expr.as_ref()).to_string();
1430 assert_eq!(
1431 sql_string,
1432 "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
1433 );
1434
1435 Ok(())
1436 }
1437}