1use crate::expressions::try_cast;
19use crate::PhysicalExpr;
20use std::borrow::Cow;
21use std::hash::Hash;
22use std::{any::Any, sync::Arc};
23
24use arrow::array::*;
25use arrow::compute::kernels::zip::zip;
26use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
27use arrow::datatypes::{DataType, Schema};
28use datafusion_common::cast::as_boolean_array;
29use datafusion_common::{
30 exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
31};
32use datafusion_expr::ColumnarValue;
33
34use super::{Column, Literal};
35use datafusion_physical_expr_common::datum::compare_with_eq;
36use itertools::Itertools;
37
/// A `(WHEN, THEN)` pair of a CASE expression: the condition (or, when the
/// CASE has a base expression, the value compared against it) and the result
/// produced for rows that match it.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
39
/// Evaluation strategy for a `CaseExpr`, selected once in `CaseExpr::try_new`
/// from the shape of the expression and dispatched on in `evaluate`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression:
    /// `CASE WHEN cond1 THEN r1 [WHEN ...] [ELSE d] END`
    NoExpression,
    /// Form with a base expression compared for equality against each WHEN value:
    /// `CASE expr WHEN v1 THEN r1 [WHEN ...] [ELSE d] END`
    WithExpression,
    /// A single `WHEN cond THEN e` with no ELSE, where `e` is cheap and
    /// infallible (currently a column reference). `e` is evaluated for every
    /// row and the rows where `cond` is not true are masked to NULL.
    InfallibleExprOrNull,
    /// A single `WHEN cond THEN lit1 ELSE lit2` where both branches are literals.
    ScalarOrScalar,
    /// A single `WHEN cond THEN e1 ELSE e2` where both branches are cheap and
    /// infallible (currently column references).
    ExpressionOrExpression,
}
70
/// The SQL CASE expression, in either of its two forms:
///
/// ```text
/// CASE WHEN cond1 THEN r1 [WHEN cond2 THEN r2 ...] [ELSE d] END
/// CASE expr WHEN v1 THEN r1 [WHEN v2 THEN r2 ...] [ELSE d] END
/// ```
///
/// WHEN branches are tried in order and the first match wins. Rows that match
/// no branch take the ELSE value, or NULL when there is no ELSE.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// Optional base expression compared for equality with each WHEN value
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// One or more (WHEN, THEN) pairs; `try_new` rejects an empty list
    when_then_expr: Vec<WhenThen>,
    /// Optional ELSE expression; a literal NULL ELSE is stored as `None`
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    /// Evaluation strategy chosen in `try_new`
    eval_method: EvalMethod,
}
99
100impl std::fmt::Display for CaseExpr {
101 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102 write!(f, "CASE ")?;
103 if let Some(e) = &self.expr {
104 write!(f, "{e} ")?;
105 }
106 for (w, t) in &self.when_then_expr {
107 write!(f, "WHEN {w} THEN {t} ")?;
108 }
109 if let Some(e) = &self.else_expr {
110 write!(f, "ELSE {e} ")?;
111 }
112 write!(f, "END")
113 }
114}
115
116fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
122 expr.as_any().is::<Column>()
123}
124
125impl CaseExpr {
126 pub fn try_new(
128 expr: Option<Arc<dyn PhysicalExpr>>,
129 when_then_expr: Vec<WhenThen>,
130 else_expr: Option<Arc<dyn PhysicalExpr>>,
131 ) -> Result<Self> {
132 let else_expr = match &else_expr {
135 Some(e) => match e.as_any().downcast_ref::<Literal>() {
136 Some(lit) if lit.value().is_null() => None,
137 _ => else_expr,
138 },
139 _ => else_expr,
140 };
141
142 if when_then_expr.is_empty() {
143 exec_err!("There must be at least one WHEN clause")
144 } else {
145 let eval_method = if expr.is_some() {
146 EvalMethod::WithExpression
147 } else if when_then_expr.len() == 1
148 && is_cheap_and_infallible(&(when_then_expr[0].1))
149 && else_expr.is_none()
150 {
151 EvalMethod::InfallibleExprOrNull
152 } else if when_then_expr.len() == 1
153 && when_then_expr[0].1.as_any().is::<Literal>()
154 && else_expr.is_some()
155 && else_expr.as_ref().unwrap().as_any().is::<Literal>()
156 {
157 EvalMethod::ScalarOrScalar
158 } else if when_then_expr.len() == 1
159 && is_cheap_and_infallible(&(when_then_expr[0].1))
160 && else_expr.as_ref().is_some_and(is_cheap_and_infallible)
161 {
162 EvalMethod::ExpressionOrExpression
163 } else {
164 EvalMethod::NoExpression
165 };
166
167 Ok(Self {
168 expr,
169 when_then_expr,
170 else_expr,
171 eval_method,
172 })
173 }
174 }
175
176 pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
178 self.expr.as_ref()
179 }
180
181 pub fn when_then_expr(&self) -> &[WhenThen] {
183 &self.when_then_expr
184 }
185
186 pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
188 self.else_expr.as_ref()
189 }
190}
191
192impl CaseExpr {
193 fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
201 let return_type = self.data_type(&batch.schema())?;
202 let expr = self.expr.as_ref().unwrap();
203 let base_value = expr.evaluate(batch)?;
204 let base_value = base_value.into_array(batch.num_rows())?;
205 let base_nulls = is_null(base_value.as_ref())?;
206
207 let mut current_value = new_null_array(&return_type, batch.num_rows());
209 let mut remainder = not(&base_nulls)?;
211 for i in 0..self.when_then_expr.len() {
212 let when_value = self.when_then_expr[i]
213 .0
214 .evaluate_selection(batch, &remainder)?;
215 let when_value = when_value.into_array(batch.num_rows())?;
216 let when_match = compare_with_eq(
218 &when_value,
219 &base_value,
220 base_value.data_type().is_nested(),
223 )?;
224 let when_match = match when_match.null_count() {
226 0 => Cow::Borrowed(&when_match),
227 _ => Cow::Owned(prep_null_mask_filter(&when_match)),
228 };
229 let when_match = and(&when_match, &remainder)?;
231
232 if when_match.true_count() == 0 {
234 continue;
235 }
236
237 let then_value = self.when_then_expr[i]
238 .1
239 .evaluate_selection(batch, &when_match)?;
240
241 current_value = match then_value {
242 ColumnarValue::Scalar(ScalarValue::Null) => {
243 nullif(current_value.as_ref(), &when_match)?
244 }
245 ColumnarValue::Scalar(then_value) => {
246 zip(&when_match, &then_value.to_scalar()?, ¤t_value)?
247 }
248 ColumnarValue::Array(then_value) => {
249 zip(&when_match, &then_value, ¤t_value)?
250 }
251 };
252
253 remainder = and_not(&remainder, &when_match)?;
254 }
255
256 if let Some(e) = self.else_expr() {
257 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
259 remainder = or(&base_nulls, &remainder)?;
261 let else_ = expr
262 .evaluate_selection(batch, &remainder)?
263 .into_array(batch.num_rows())?;
264 current_value = zip(&remainder, &else_, ¤t_value)?;
265 }
266
267 Ok(ColumnarValue::Array(current_value))
268 }
269
270 fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
278 let return_type = self.data_type(&batch.schema())?;
279
280 let mut current_value = new_null_array(&return_type, batch.num_rows());
282 let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
283 for i in 0..self.when_then_expr.len() {
284 let when_value = self.when_then_expr[i]
285 .0
286 .evaluate_selection(batch, &remainder)?;
287 let when_value = when_value.into_array(batch.num_rows())?;
288 let when_value = as_boolean_array(&when_value).map_err(|_| {
289 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
290 })?;
291 let when_value = match when_value.null_count() {
293 0 => Cow::Borrowed(when_value),
294 _ => Cow::Owned(prep_null_mask_filter(when_value)),
295 };
296 let when_value = and(&when_value, &remainder)?;
298
299 if when_value.true_count() == 0 {
301 continue;
302 }
303
304 let then_value = self.when_then_expr[i]
305 .1
306 .evaluate_selection(batch, &when_value)?;
307
308 current_value = match then_value {
309 ColumnarValue::Scalar(ScalarValue::Null) => {
310 nullif(current_value.as_ref(), &when_value)?
311 }
312 ColumnarValue::Scalar(then_value) => {
313 zip(&when_value, &then_value.to_scalar()?, ¤t_value)?
314 }
315 ColumnarValue::Array(then_value) => {
316 zip(&when_value, &then_value, ¤t_value)?
317 }
318 };
319
320 remainder = and_not(&remainder, &when_value)?;
323 }
324
325 if let Some(e) = self.else_expr() {
326 if remainder.true_count() > 0 {
327 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
329 let else_ = expr
330 .evaluate_selection(batch, &remainder)?
331 .into_array(batch.num_rows())?;
332 current_value = zip(&remainder, &else_, ¤t_value)?;
333 }
334 }
335
336 Ok(ColumnarValue::Array(current_value))
337 }
338
339 fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
349 let when_expr = &self.when_then_expr[0].0;
350 let then_expr = &self.when_then_expr[0].1;
351
352 match when_expr.evaluate(batch)? {
353 ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
355 then_expr.evaluate(batch)
356 }
357 ColumnarValue::Scalar(_) => {
359 ScalarValue::try_from(self.data_type(&batch.schema())?)
361 .map(ColumnarValue::Scalar)
362 }
363 ColumnarValue::Array(bit_mask) => {
365 let bit_mask = bit_mask
366 .as_any()
367 .downcast_ref::<BooleanArray>()
368 .expect("predicate should evaluate to a boolean array");
369 let bit_mask = match bit_mask.null_count() {
371 0 => not(bit_mask)?,
372 _ => not(&prep_null_mask_filter(bit_mask))?,
373 };
374 match then_expr.evaluate(batch)? {
375 ColumnarValue::Array(array) => {
376 Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
377 }
378 ColumnarValue::Scalar(_) => {
379 internal_err!("expression did not evaluate to an array")
380 }
381 }
382 }
383 }
384 }
385
386 fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
387 let return_type = self.data_type(&batch.schema())?;
388
389 let when_value = self.when_then_expr[0].0.evaluate(batch)?;
391 let when_value = when_value.into_array(batch.num_rows())?;
392 let when_value = as_boolean_array(&when_value).map_err(|_| {
393 internal_datafusion_err!("WHEN expression did not return a BooleanArray")
394 })?;
395
396 let when_value = match when_value.null_count() {
398 0 => Cow::Borrowed(when_value),
399 _ => Cow::Owned(prep_null_mask_filter(when_value)),
400 };
401
402 let then_value = self.when_then_expr[0].1.evaluate(batch)?;
404 let then_value = Scalar::new(then_value.into_array(1)?);
405
406 let Some(e) = self.else_expr() else {
407 return internal_err!("expression did not evaluate to an array");
408 };
409 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
411 let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
412 Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
413 }
414
415 fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
416 let return_type = self.data_type(&batch.schema())?;
417
418 let when_value = self.when_then_expr[0].0.evaluate(batch)?;
420 let when_value = when_value.into_array(batch.num_rows())?;
421 let when_value = as_boolean_array(&when_value).map_err(|e| {
422 DataFusionError::Context(
423 "WHEN expression did not return a BooleanArray".to_string(),
424 Box::new(e),
425 )
426 })?;
427
428 let when_value = match when_value.null_count() {
430 0 => Cow::Borrowed(when_value),
431 _ => Cow::Owned(prep_null_mask_filter(when_value)),
432 };
433
434 let then_value = self.when_then_expr[0]
435 .1
436 .evaluate_selection(batch, &when_value)?
437 .into_array(batch.num_rows())?;
438
439 let remainder = not(&when_value)?;
441 let e = self.else_expr.as_ref().unwrap();
442 let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
444 .unwrap_or_else(|_| Arc::clone(e));
445 let else_ = expr
446 .evaluate_selection(batch, &remainder)?
447 .into_array(batch.num_rows())?;
448
449 Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
450 }
451}
452
453impl PhysicalExpr for CaseExpr {
454 fn as_any(&self) -> &dyn Any {
456 self
457 }
458
459 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
460 let mut data_type = DataType::Null;
463 for i in 0..self.when_then_expr.len() {
464 data_type = self.when_then_expr[i].1.data_type(input_schema)?;
465 if !data_type.equals_datatype(&DataType::Null) {
466 break;
467 }
468 }
469 if data_type.equals_datatype(&DataType::Null) {
471 if let Some(e) = &self.else_expr {
472 data_type = e.data_type(input_schema)?;
473 }
474 }
475
476 Ok(data_type)
477 }
478
479 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
480 let then_nullable = self
482 .when_then_expr
483 .iter()
484 .map(|(_, t)| t.nullable(input_schema))
485 .collect::<Result<Vec<_>>>()?;
486 if then_nullable.contains(&true) {
487 Ok(true)
488 } else if let Some(e) = &self.else_expr {
489 e.nullable(input_schema)
490 } else {
491 Ok(true)
494 }
495 }
496
497 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
498 match self.eval_method {
499 EvalMethod::WithExpression => {
500 self.case_when_with_expr(batch)
503 }
504 EvalMethod::NoExpression => {
505 self.case_when_no_expr(batch)
508 }
509 EvalMethod::InfallibleExprOrNull => {
510 self.case_column_or_null(batch)
512 }
513 EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
514 EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
515 }
516 }
517
518 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
519 let mut children = vec![];
520 if let Some(expr) = &self.expr {
521 children.push(expr)
522 }
523 self.when_then_expr.iter().for_each(|(cond, value)| {
524 children.push(cond);
525 children.push(value);
526 });
527
528 if let Some(else_expr) = &self.else_expr {
529 children.push(else_expr)
530 }
531 children
532 }
533
534 fn with_new_children(
536 self: Arc<Self>,
537 children: Vec<Arc<dyn PhysicalExpr>>,
538 ) -> Result<Arc<dyn PhysicalExpr>> {
539 if children.len() != self.children().len() {
540 internal_err!("CaseExpr: Wrong number of children")
541 } else {
542 let (expr, when_then_expr, else_expr) =
543 match (self.expr().is_some(), self.else_expr().is_some()) {
544 (true, true) => (
545 Some(&children[0]),
546 &children[1..children.len() - 1],
547 Some(&children[children.len() - 1]),
548 ),
549 (true, false) => {
550 (Some(&children[0]), &children[1..children.len()], None)
551 }
552 (false, true) => (
553 None,
554 &children[0..children.len() - 1],
555 Some(&children[children.len() - 1]),
556 ),
557 (false, false) => (None, &children[0..children.len()], None),
558 };
559 Ok(Arc::new(CaseExpr::try_new(
560 expr.cloned(),
561 when_then_expr.iter().cloned().tuples().collect(),
562 else_expr.cloned(),
563 )?))
564 }
565 }
566
567 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
568 write!(f, "CASE ")?;
569 if let Some(e) = &self.expr {
570 e.fmt_sql(f)?;
571 write!(f, " ")?;
572 }
573
574 for (w, t) in &self.when_then_expr {
575 write!(f, "WHEN ")?;
576 w.fmt_sql(f)?;
577 write!(f, " THEN ")?;
578 t.fmt_sql(f)?;
579 write!(f, " ")?;
580 }
581
582 if let Some(e) = &self.else_expr {
583 write!(f, "ELSE ")?;
584 e.fmt_sql(f)?;
585 write!(f, " ")?;
586 }
587 write!(f, "END")
588 }
589}
590
591pub fn case(
593 expr: Option<Arc<dyn PhysicalExpr>>,
594 when_thens: Vec<WhenThen>,
595 else_expr: Option<Arc<dyn PhysicalExpr>>,
596) -> Result<Arc<dyn PhysicalExpr>> {
597 Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::{binary, cast, col, lit, BinaryExpr};
    use arrow::buffer::Buffer;
    use arrow::datatypes::DataType::Float64;
    use arrow::datatypes::Field;
    use datafusion_common::cast::{as_float64_array, as_int32_array};
    use datafusion_common::plan_err;
    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
    use datafusion_expr::type_coercion::binary::comparison_coercion;
    use datafusion_expr::Operator;
    use datafusion_physical_expr_common::physical_expr::fmt_sql;

    #[test]
    fn case_with_expr() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
        let when1 = lit("foo");
        let then1 = lit(123i32);
        let when2 = lit("bar");
        let then2 = lit(456i32);

        let expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(when1, then1), (when2, then2)],
            None,
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_int32_array(&result)?;

        // Unmatched ('baz') and NULL rows are NULL because there is no ELSE
        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_with_expr_else() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
        let when1 = lit("foo");
        let then1 = lit(123i32);
        let when2 = lit("bar");
        let then2 = lit(456i32);
        let else_value = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(when1, then1), (when2, then2)],
            Some(else_value),
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_int32_array(&result)?;

        // Both the unmatched row and the NULL-base row take the ELSE value
        let expected =
            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_with_expr_divide_by_zero() -> Result<()> {
        let batch = case_test_batch1()?;
        let schema = batch.schema();

        // CASE a WHEN 0 THEN NULL ELSE 25.0 / CAST(a AS Float64) END
        let when1 = lit(0i32);
        let then1 = lit(ScalarValue::Float64(None));
        let else_value = binary(
            lit(25.0f64),
            Operator::Divide,
            cast(col("a", &schema)?, &batch.schema(), Float64)?,
            &batch.schema(),
        )?;

        let expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(when1, then1)],
            Some(else_value),
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");

        // a = [1, 0, NULL, 5]: the zero row matches the WHEN and yields NULL
        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_without_expr() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
        let when1 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("foo"),
            &batch.schema(),
        )?;
        let then1 = lit(123i32);
        let when2 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("bar"),
            &batch.schema(),
        )?;
        let then2 = lit(456i32);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when1, then1), (when2, then2)],
            None,
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_int32_array(&result)?;

        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_with_expr_when_null() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
        let when1 = lit(ScalarValue::Utf8(None));
        let then1 = lit(0i32);
        let when2 = col("a", &schema)?;
        let then2 = lit(123i32);
        let else_value = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(when1, then1), (when2, then2)],
            Some(else_value),
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_int32_array(&result)?;

        // `WHEN NULL` never matches; every non-NULL `a` equals itself; the
        // NULL row falls through to ELSE
        let expected =
            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_without_expr_divide_by_zero() -> Result<()> {
        let batch = case_test_batch1()?;
        let schema = batch.schema();

        // CASE WHEN a > 0 THEN 25.0 / CAST(a AS Float64) ELSE NULL END
        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
        let then1 = binary(
            lit(25.0f64),
            Operator::Divide,
            cast(col("a", &schema)?, &batch.schema(), Float64)?,
            &batch.schema(),
        )?;
        let x = lit(ScalarValue::Float64(None));

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when1, then1)],
            Some(x),
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");

        // a = [1, 0, NULL, 5]: rows where a is 0 or NULL fall through to NULL
        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);

        assert_eq!(expected, result);

        Ok(())
    }

    /// Three nullable Int32 columns:
    /// a = [1, 0, NULL, 5], b = [3, NULL, 14, 7], c = [0, -3, 777, NULL]
    fn case_test_batch1() -> Result<RecordBatch> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]);
        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
        )?;
        Ok(batch)
    }

    #[test]
    fn case_without_expr_else() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
        let when1 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("foo"),
            &batch.schema(),
        )?;
        let then1 = lit(123i32);
        let when2 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("bar"),
            &batch.schema(),
        )?;
        let then2 = lit(456i32);
        let else_value = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when1, then1), (when2, then2)],
            Some(else_value),
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_int32_array(&result)?;

        let expected =
            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_with_type_cast() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
        // The Float64 THEN and Int32 ELSE are coerced to a common Float64 type
        let when = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("foo"),
            &batch.schema(),
        )?;
        let then = lit(123.3f64);
        let else_value = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when, then)],
            Some(else_value),
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");

        let expected =
            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_with_matches_and_nulls() -> Result<()> {
        let batch = case_test_batch_nulls()?;
        let schema = batch.schema();

        // CASE WHEN load4 = 1.77 THEN load4 END
        let when = binary(
            col("load4", &schema)?,
            Operator::Eq,
            lit(1.77f64),
            &batch.schema(),
        )?;
        let then = col("load4", &schema)?;

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when, then)],
            None,
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");

        // Null slots (even those backed by the value 1.77) must not match
        let expected =
            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn case_with_scalar_predicate() -> Result<()> {
        let batch = case_test_batch_nulls()?;
        let schema = batch.schema();

        // CASE WHEN true THEN load4 END
        let when = lit(true);
        let then = col("load4", &schema)?;
        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when, then)],
            None,
            schema.as_ref(),
        )?;

        // The result is load4 unchanged, including its nulls
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");
        let expected = &Float64Array::from(vec![
            Some(1.77),
            None,
            None,
            Some(1.78),
            None,
            Some(1.77),
        ]);
        assert_eq!(expected, result);

        // Evaluate the same expression against a different, single-row batch
        let expected = Float64Array::from(vec![Some(1.1)]);
        let batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");
        assert_eq!(&expected, result);

        Ok(())
    }

    #[test]
    fn case_expr_matches_and_nulls() -> Result<()> {
        let batch = case_test_batch_nulls()?;
        let schema = batch.schema();

        // CASE load4 WHEN 1.77 THEN load4 END
        let expr = col("load4", &schema)?;
        let when = lit(1.77f64);
        let then = col("load4", &schema)?;

        let expr = generate_case_when_with_type_coercion(
            Some(expr),
            vec![(when, then)],
            None,
            schema.as_ref(),
        )?;
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result =
            as_float64_array(&result).expect("failed to downcast to Float64Array");

        let expected =
            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);

        assert_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_when_null_and_some_cond_else_null() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // CASE WHEN NULL AND a = 'foo' THEN a END
        let when = binary(
            Arc::new(Literal::new(ScalarValue::Boolean(None))),
            Operator::And,
            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
            &schema,
        )?;
        let then = col("a", &schema)?;

        // Built directly with try_new (no type coercion helper)
        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_string_array(&result);

        // The predicate is never true, so every row is NULL
        assert_eq!(result.logical_null_count(), batch.num_rows());
        Ok(())
    }

    /// One nullable Utf8 column: a = ['foo', 'baz', NULL, 'bar']
    fn case_test_batch() -> Result<RecordBatch> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
        Ok(batch)
    }

    /// One Float64 column `load4` whose validity bitmap hides some values.
    ///
    /// The bitmap 0b00101001 marks only slots 0, 3 and 5 as valid, so the
    /// logical values are [1.77, NULL, NULL, 1.78, NULL, 1.77]. Slots 1 and 2
    /// still hold 1.77 in the value buffer.
    fn case_test_batch_nulls() -> Result<RecordBatch> {
        let load4: Float64Array = vec![
            Some(1.77), // valid
            Some(1.77), // masked to NULL by the bitmap below
            Some(1.77), // masked to NULL by the bitmap below
            Some(1.78), // valid
            None,       // NULL
            Some(1.77), // valid
        ]
        .into_iter()
        .collect();

        let null_buffer = Buffer::from([0b00101001u8]);
        let load4 = load4
            .into_data()
            .into_builder()
            .null_bit_buffer(Some(null_buffer))
            .build()
            .unwrap();
        let load4: Float64Array = load4.into();

        let batch =
            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
        Ok(batch)
    }

    #[test]
    fn case_test_incompatible() -> Result<()> {
        let batch = case_test_batch()?;
        let schema = batch.schema();

        // THEN branches of type Int32 and Boolean have no common type
        let when1 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("foo"),
            &batch.schema(),
        )?;
        let then1 = lit(123i32);
        let when2 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("bar"),
            &batch.schema(),
        )?;
        let then2 = lit(true);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when1, then1), (when2, then2)],
            None,
            schema.as_ref(),
        );
        assert!(expr.is_err());

        // Int32 and Int64 THEN branches with a Float64 ELSE coerce to Float64
        let when1 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("foo"),
            &batch.schema(),
        )?;
        let then1 = lit(123i32);
        let when2 = binary(
            col("a", &schema)?,
            Operator::Eq,
            lit("bar"),
            &batch.schema(),
        )?;
        let then2 = lit(456i64);
        let else_expr = lit(1.23f64);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when1, then1), (when2, then2)],
            Some(else_expr),
            schema.as_ref(),
        );
        assert!(expr.is_ok());
        let result_type = expr.unwrap().data_type(schema.as_ref())?;
        assert_eq!(Float64, result_type);
        Ok(())
    }

    #[test]
    fn case_eq() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        let when1 = lit("foo");
        let then1 = lit(123i32);
        let when2 = lit("bar");
        let then2 = lit(456i32);
        let else_value = lit(999i32);

        // expr1 and expr2 are built from identical parts
        let expr1 = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when1), Arc::clone(&then1)),
                (Arc::clone(&when2), Arc::clone(&then2)),
            ],
            Some(Arc::clone(&else_value)),
            &schema,
        )?;

        let expr2 = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when1), Arc::clone(&then1)),
                (Arc::clone(&when2), Arc::clone(&then2)),
            ],
            Some(Arc::clone(&else_value)),
            &schema,
        )?;

        // expr3 differs only by lacking an ELSE
        let expr3 = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
            None,
            &schema,
        )?;

        // expr4 differs only by having a single WHEN/THEN pair
        let expr4 = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![(when1, then1)],
            Some(else_value),
            &schema,
        )?;

        assert!(expr1.eq(&expr2));
        assert!(expr2.eq(&expr1));

        assert!(expr2.ne(&expr3));
        assert!(expr3.ne(&expr2));

        assert!(expr1.ne(&expr4));
        assert!(expr4.ne(&expr1));

        Ok(())
    }

    #[test]
    fn case_transform() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        let when1 = lit("foo");
        let then1 = lit(123i32);
        let when2 = lit("bar");
        let then2 = lit(456i32);
        let else_value = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            vec![
                (Arc::clone(&when1), Arc::clone(&then1)),
                (Arc::clone(&when2), Arc::clone(&then2)),
            ],
            Some(Arc::clone(&else_value)),
            &schema,
        )?;

        // Upper-case every Utf8 literal with a bottom-up traversal...
        let expr2 = Arc::clone(&expr)
            .transform(|e| {
                let transformed = match e.as_any().downcast_ref::<Literal>() {
                    Some(lit_value) => match lit_value.value() {
                        ScalarValue::Utf8(Some(str_value)) => {
                            Some(lit(str_value.to_uppercase()))
                        }
                        _ => None,
                    },
                    _ => None,
                };
                Ok(if let Some(transformed) = transformed {
                    Transformed::yes(transformed)
                } else {
                    Transformed::no(e)
                })
            })
            .data()
            .unwrap();

        // ...and with a top-down traversal; both must rebuild the same tree
        let expr3 = Arc::clone(&expr)
            .transform_down(|e| {
                let transformed = match e.as_any().downcast_ref::<Literal>() {
                    Some(lit_value) => match lit_value.value() {
                        ScalarValue::Utf8(Some(str_value)) => {
                            Some(lit(str_value.to_uppercase()))
                        }
                        _ => None,
                    },
                    _ => None,
                };
                Ok(if let Some(transformed) = transformed {
                    Transformed::yes(transformed)
                } else {
                    Transformed::no(e)
                })
            })
            .data()
            .unwrap();

        assert!(expr.ne(&expr2));
        assert!(expr2.eq(&expr3));

        Ok(())
    }

    #[test]
    fn test_column_or_null_specialization() -> Result<()> {
        // c1 = 0..1000; c2 = "string {i}", or NULL when i is a multiple of 7
        let mut c1 = Int32Builder::new();
        let mut c2 = StringBuilder::new();
        for i in 0..1000 {
            c1.append_value(i);
            if i % 7 == 0 {
                c2.append_null();
            } else {
                c2.append_value(format!("string {i}"));
            }
        }
        let c1 = Arc::new(c1.finish());
        let c2 = Arc::new(c2.finish());
        let schema = Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
        ]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();

        // CASE WHEN c1 <= 250 THEN c2 END
        let predicate = Arc::new(BinaryExpr::new(
            make_col("c1", 0),
            Operator::LtEq,
            make_lit_i32(250),
        ));
        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
        assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
        match expr.evaluate(&batch)? {
            ColumnarValue::Array(array) => {
                assert_eq!(1000, array.len());
                // 749 rows fail the predicate, plus 36 NULL c2 values among the
                // 251 rows that pass it
                assert_eq!(785, array.null_count());
            }
            _ => unreachable!(),
        }
        Ok(())
    }

    #[test]
    fn test_expr_or_expr_specialization() -> Result<()> {
        let batch = case_test_batch1()?;
        let schema = batch.schema();
        // CASE WHEN a <= 2 THEN b ELSE c END
        let when = binary(
            col("a", &schema)?,
            Operator::LtEq,
            lit(2i32),
            &batch.schema(),
        )?;
        let then = col("b", &schema)?;
        let else_expr = col("c", &schema)?;
        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
        assert!(matches!(
            expr.eval_method,
            EvalMethod::ExpressionOrExpression
        ));
        let result = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");

        // Row 2 has a NULL `a`, so its predicate is NULL and it takes c = 777
        let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);

        assert_eq!(expected, result);
        Ok(())
    }

    /// Column reference expression with the given name and index
    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Int32 literal expression
    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Build a CASE expression after coercing every THEN and the ELSE branch to
    /// their common type (via `try_cast`), returning a plan error when no common
    /// type exists.
    fn generate_case_when_with_type_coercion(
        expr: Option<Arc<dyn PhysicalExpr>>,
        when_thens: Vec<WhenThen>,
        else_expr: Option<Arc<dyn PhysicalExpr>>,
        input_schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let coerce_type =
            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
        let (when_thens, else_expr) = match coerce_type {
            None => plan_err!(
                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
            ),
            Some(data_type) => {
                // Cast each THEN branch; WHEN expressions are left untouched
                let left = when_thens
                    .into_iter()
                    .map(|(when, then)| {
                        let then = try_cast(then, input_schema, data_type.clone())?;
                        Ok((when, then))
                    })
                    .collect::<Result<Vec<_>>>()?;
                let right = match else_expr {
                    None => None,
                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
                };

                Ok((left, right))
            }
        }?;
        case(expr, when_thens, else_expr)
    }

    /// Fold `comparison_coercion` over all THEN types, starting from the ELSE
    /// type (or from the first THEN type when there is no ELSE).
    ///
    /// Test helper: panics if a branch's data type cannot be computed or if
    /// `when_thens` is empty and there is no ELSE.
    fn get_case_common_type(
        when_thens: &[WhenThen],
        else_expr: Option<Arc<dyn PhysicalExpr>>,
        input_schema: &Schema,
    ) -> Option<DataType> {
        let thens_type = when_thens
            .iter()
            .map(|when_then| {
                let data_type = &when_then.1.data_type(input_schema).unwrap();
                data_type.clone()
            })
            .collect::<Vec<_>>();
        let else_type = match else_expr {
            None => {
                // No ELSE: seed the fold with the first THEN type
                thens_type[0].clone()
            }
            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
        };
        thens_type
            .iter()
            .try_fold(else_type, |left_type, right_type| {
                // `None` from any step means there is no common type
                comparison_coercion(&left_type, right_type)
            })
    }

    #[test]
    fn test_fmt_sql() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);

        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END (ELSE coerced to Float64)
        let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let then = lit(123.3f64);
        let else_value = lit(999i32);

        let expr = generate_case_when_with_type_coercion(
            None,
            vec![(when, then)],
            Some(else_value),
            &schema,
        )?;

        // Display includes the column index (`a@0`)...
        let display_string = expr.to_string();
        assert_eq!(
            display_string,
            "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
        );

        // ...while the SQL form uses the bare column name
        let sql_string = fmt_sql(expr.as_ref()).to_string();
        assert_eq!(
            sql_string,
            "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
        );

        Ok(())
    }
}