Skip to main content

datafusion_physical_expr/expressions/
case.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod literal_lookup_table;
19
20use super::{Column, Literal};
21use crate::PhysicalExpr;
22use crate::expressions::{LambdaVariable, lit, try_cast};
23use arrow::array::*;
24use arrow::compute::kernels::zip::zip;
25use arrow::compute::{
26    FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::{
32    DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err,
33    internal_datafusion_err, internal_err,
34};
35use datafusion_expr::ColumnarValue;
36use indexmap::IndexMap;
37use std::borrow::Cow;
38use std::collections::BTreeSet;
39use std::hash::Hash;
40use std::sync::Arc;
41
42use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
43use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
44use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
45use datafusion_physical_expr_common::datum::compare_with_eq;
46use datafusion_physical_expr_common::utils::scatter;
47use itertools::Itertools;
48use std::fmt::{Debug, Formatter};
49
/// A single `WHEN ... THEN ...` branch of a `CASE` expression, stored as the pair
/// `(when, then)`.
pub(super) type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
51
/// The strategy used to evaluate a [`CaseExpr`].
///
/// The method is selected once, by `CaseExpr::find_best_eval_method`, when the expression
/// is constructed.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General evaluation of the form without a base expression.
    ///
    /// ```text
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    /// ```
    NoExpression(ProjectedCaseBody),
    /// General evaluation of the form with a base expression.
    ///
    /// ```text
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    /// ```
    WithExpression(ProjectedCaseBody),
    /// This is a specialization for a specific use case where we can take a fast path
    /// for expressions that are infallible and can be cheaply computed for the entire
    /// record batch rather than just for the rows where the predicate is true.
    ///
    /// ```text
    /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END
    /// ```
    InfallibleExprOrNull,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair and both the `then` and `else` expressions
    /// are literal values.
    ///
    /// ```text
    /// CASE WHEN condition THEN literal ELSE literal END
    /// ```
    ScalarOrScalar,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair, the `then` is an expression, and `else` is either
    /// an expression, literal NULL or absent.
    ///
    /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible
    /// `then` expressions.
    ///
    /// ```text
    /// CASE WHEN condition THEN expression [ELSE expression] END
    /// ```
    ExpressionOrExpression(ProjectedCaseBody),

    /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals
    ///
    /// See [`LiteralLookupTable`] for more details
    WithExprScalarLookupTable(LiteralLookupTable),
}
91
/// Deliberately empty [`Hash`] implementation so that `derive` can be used on [`EvalMethod`].
///
/// A real implementation is not provided because [`Hash`] is not dyn compatible, so it cannot
/// be implemented for `dyn` [`literal_lookup_table::WhenLiteralIndexMap`].
///
/// Feeding nothing into the hasher is still valid: the table's data is derived from
/// `PhysicalExpr`s which are already hashed.
impl Hash for LiteralLookupTable {
    // Intentionally contributes nothing to `_state`; see the note on the impl.
    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
}
101
/// Trivial equality so that `derive` can be used on [`EvalMethod`].
///
/// A real implementation is not provided because [`PartialEq`] is not dyn compatible, so it
/// cannot be implemented for `dyn` [`literal_lookup_table::WhenLiteralIndexMap`].
///
/// Always returning `true` is acceptable because the table's data is derived from
/// `PhysicalExpr`s which are already compared.
impl PartialEq for LiteralLookupTable {
    // Every pair of lookup tables is reported as equal; see the note on the impl.
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

// Marker impl: the always-true `eq` above is what `Eq` builds on.
impl Eq for LiteralLookupTable {}
115
/// The body of a CASE expression which consists of an optional base expression, the "when/then"
/// branches and an optional "else" branch.
///
/// Bodies built through `CaseExpr::try_new` always have at least one when/then branch, and a
/// literal NULL "else" has been normalized to `None`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression that can be compared to literal values in the "when" expressions
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// One or more when/then expressions, each stored as a `(when, then)` pair
    when_then_expr: Vec<WhenThen>,
    /// Optional "else" expression
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
127
128impl CaseBody {
129    /// Derives a [ProjectedCaseBody] from this [CaseBody].
130    fn project(&self) -> Result<ProjectedCaseBody> {
131        // Determine the set of columns that are used in all the expressions of the case body.
132        // Use an ordered set so lambda variables continue to be positioned after columns
133        let mut used_column_indices = BTreeSet::<usize>::new();
134        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
135            expr.apply(|expr| {
136                if let Some(column) = expr.downcast_ref::<Column>() {
137                    used_column_indices.insert(column.index());
138                } else if let Some(lambda_variable) =
139                    expr.downcast_ref::<LambdaVariable>()
140                {
141                    used_column_indices.insert(lambda_variable.index());
142                }
143                Ok(TreeNodeRecursion::Continue)
144            })
145            .expect("Closure cannot fail");
146        };
147
148        if let Some(e) = &self.expr {
149            collect_column_indices(e);
150        }
151        self.when_then_expr.iter().for_each(|(w, t)| {
152            collect_column_indices(w);
153            collect_column_indices(t);
154        });
155        if let Some(e) = &self.else_expr {
156            collect_column_indices(e);
157        }
158
159        // Construct a mapping from the original column index to the projected column index.
160        let column_index_map = used_column_indices
161            .iter()
162            .enumerate()
163            .map(|(projected, original)| (*original, projected))
164            .collect::<IndexMap<usize, usize>>();
165
166        // Construct the projected body by rewriting each expression from the original body
167        // using the column index mapping.
168        let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
169            Arc::clone(expr)
170                .transform_down(|e| {
171                    if let Some(column) = e.downcast_ref::<Column>() {
172                        let original = column.index();
173                        let projected = *column_index_map.get(&original).unwrap();
174                        if projected != original {
175                            return Ok(Transformed::yes(Arc::new(Column::new(
176                                column.name(),
177                                projected,
178                            ))));
179                        }
180                    } else if let Some(lambda_variable) =
181                        e.downcast_ref::<LambdaVariable>()
182                    {
183                        let original = lambda_variable.index();
184                        let projected = *column_index_map.get(&original).unwrap();
185                        if projected != original {
186                            return Ok(Transformed::yes(Arc::new(LambdaVariable::new(
187                                projected,
188                                Arc::clone(lambda_variable.field()),
189                            ))));
190                        }
191                    }
192                    Ok(Transformed::no(e))
193                })
194                .map(|t| t.data)
195        };
196
197        let projected_body = CaseBody {
198            expr: self.expr.as_ref().map(project).transpose()?,
199            when_then_expr: self
200                .when_then_expr
201                .iter()
202                .map(|(e, t)| Ok((project(e)?, project(t)?)))
203                .collect::<Result<Vec<_>>>()?,
204            else_expr: self.else_expr.as_ref().map(project).transpose()?,
205        };
206
207        // Construct the projection vector
208        let projection = column_index_map
209            .iter()
210            .sorted_by_key(|(_, v)| **v)
211            .map(|(k, _)| *k)
212            .collect::<Vec<_>>();
213
214        Ok(ProjectedCaseBody {
215            projection,
216            body: projected_body,
217        })
218    }
219}
220
/// A derived case body that can be used to evaluate a case expression after projecting
/// record batches using a projection vector.
///
/// This is used to avoid filtering columns that are not used in the
/// input `RecordBatch` when progressively evaluating a `CASE` expression's
/// remainder batches. Filtering these columns is wasteful since for a record
/// batch of `n` rows, filtering requires at worst a copy of `n - 1` values
/// per array. If these filtered values will never be accessed, the time spent
/// producing them is better avoided.
///
/// For example, if we are evaluating the following case expression that
/// only references columns B and D:
///
/// ```sql
/// SELECT CASE WHEN B > 10 THEN D ELSE NULL END FROM (VALUES (...)) T(A, B, C, D)
/// ```
///
/// Of the 4 input columns `[A, B, C, D]`, the `CASE` expression only accesses `B` and `D`.
/// Filtering `A` and `C` would be unnecessary and wasteful.
///
/// If we only retain columns `B` and `D` using `RecordBatch::project` and the projection vector
/// `[1, 3]`, the indices of these two columns will change to `[0, 1]`. To evaluate the
/// case expression, it will need to be rewritten from `CASE WHEN B@1 > 10 THEN D@3 ELSE NULL END`
/// to `CASE WHEN B@0 > 10 THEN D@1 ELSE NULL END`.
///
/// The projection vector and the rewritten expression (which only differs from the original in
/// column reference indices) are held in a `ProjectedCaseBody`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// Original indices of the referenced columns, ordered by their projected position.
    projection: Vec<usize>,
    /// The case body with its column references rewritten to projected indices.
    body: CaseBody,
}
253
/// The CASE expression is similar to a series of nested if/else and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
/// ```text
/// CASE WHEN condition THEN result
///      [WHEN ...]
///      [ELSE result]
/// END
/// ```
///
/// The second form uses a base expression and then a series of "when" clauses that match on a
/// literal value.
///
/// ```text
/// CASE expression
///     WHEN value THEN result
///     [WHEN ...]
///     [ELSE result]
/// END
/// ```
#[derive(Debug)]
pub struct CaseExpr {
    /// The case expression body
    body: CaseBody,
    /// Evaluation method to use; chosen from `body` when the expression is created
    eval_method: EvalMethod,
}
278
279// eval_method is functionally derived from body, so excluding it from
280// Hash/Eq avoids redundantly hashing the expression tree twice. For
281// nested CASE chains this prevents exponential blowup (see #22173).
282impl Hash for CaseExpr {
283    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
284        self.body.hash(state);
285    }
286}
287
288impl PartialEq for CaseExpr {
289    fn eq(&self, other: &Self) -> bool {
290        self.body == other.body
291    }
292}
293
294impl Eq for CaseExpr {}
295
296impl std::fmt::Display for CaseExpr {
297    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
298        write!(f, "CASE ")?;
299        if let Some(e) = &self.body.expr {
300            write!(f, "{e} ")?;
301        }
302        for (w, t) in &self.body.when_then_expr {
303            write!(f, "WHEN {w} THEN {t} ")?;
304        }
305        if let Some(e) = &self.body.else_expr {
306            write!(f, "ELSE {e} ")?;
307        }
308        write!(f, "END")
309    }
310}
311
/// Reports whether `expr` qualifies for the [`EvalMethod::InfallibleExprOrNull`] fast path.
///
/// That fast path is for expressions that are infallible and can be cheaply computed for the
/// entire record batch rather than just for the rows where the predicate is true. For now,
/// this is limited to use with [`Column`] expressions but could potentially be used for other
/// expressions in the future.
fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
    expr.is::<Column>()
}
320
321/// Creates a [FilterPredicate] from a boolean array.
322fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
323    let mut filter_builder = FilterBuilder::new(predicate);
324    if optimize {
325        // Always optimize the filter since we use them multiple times.
326        filter_builder = filter_builder.optimize();
327    }
328    filter_builder.build()
329}
330
331fn multiple_arrays(data_type: &DataType) -> bool {
332    match data_type {
333        DataType::Struct(fields) => {
334            fields.len() > 1
335                || fields.len() == 1 && multiple_arrays(fields[0].data_type())
336        }
337        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
338        _ => false,
339    }
340}
341
342// This should be removed when https://github.com/apache/arrow-rs/pull/8693
343// is merged and becomes available.
344fn filter_record_batch(
345    record_batch: &RecordBatch,
346    filter: &FilterPredicate,
347) -> std::result::Result<RecordBatch, ArrowError> {
348    let filtered_columns = record_batch
349        .columns()
350        .iter()
351        .map(|a| filter_array(a, filter))
352        .collect::<std::result::Result<Vec<_>, _>>()?;
353    // SAFETY: since we start from a valid RecordBatch, there's no need to revalidate the schema
354    // since the set of columns has not changed.
355    // The input column arrays all had the same length (since they're coming from a valid RecordBatch)
356    // and the filtering them with the same filter will produces a new set of arrays with identical
357    // lengths.
358    unsafe {
359        Ok(RecordBatch::new_unchecked(
360            record_batch.schema(),
361            filtered_columns,
362            filter.count(),
363        ))
364    }
365}
366
// This function exists purely to be able to use the same call style
// for `filter_record_batch` and `filter_array` at the point of use.
// When https://github.com/apache/arrow-rs/pull/8693 is available, replace
// both with method calls on `FilterPredicate`.
/// Applies `filter` to `array` by delegating to the predicate's own `filter` method.
#[inline(always)]
fn filter_array(
    array: &dyn Array,
    filter: &FilterPredicate,
) -> std::result::Result<ArrayRef, ArrowError> {
    filter.filter(array)
}
378
/// An index into the partial results array that's more compact than `usize`.
///
/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
/// `Option` to keep the array of indices as compact as possible.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    /// Either a position in the partial results array or the [`NONE_VALUE`] sentinel.
    index: u32,
}

/// Sentinel stored in a [`PartialResultIndex`] to represent the absence of an index.
const NONE_VALUE: u32 = u32::MAX;
389
390impl PartialResultIndex {
391    /// Returns the 'none' placeholder value.
392    fn none() -> Self {
393        Self { index: NONE_VALUE }
394    }
395
396    fn zero() -> Self {
397        Self { index: 0 }
398    }
399
400    /// Creates a new partial result index.
401    ///
402    /// If the provided value is greater than or equal to `u32::MAX`
403    /// an error will be returned.
404    fn try_new(index: usize) -> Result<Self> {
405        let Ok(index) = u32::try_from(index) else {
406            return internal_err!("Partial result index exceeds limit");
407        };
408
409        assert_or_internal_err!(
410            index != NONE_VALUE,
411            "Partial result index exceeds limit"
412        );
413
414        Ok(Self { index })
415    }
416
417    /// Determines if this index is the 'none' placeholder value or not.
418    fn is_none(&self) -> bool {
419        self.index == NONE_VALUE
420    }
421}
422
423impl MergeIndex for PartialResultIndex {
424    /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise.
425    fn index(&self) -> Option<usize> {
426        if self.is_none() {
427            None
428        } else {
429            Some(self.index as usize)
430        }
431    }
432}
433
434impl Debug for PartialResultIndex {
435    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
436        if self.is_none() {
437            write!(f, "null")
438        } else {
439            write!(f, "{}", self.index)
440        }
441    }
442}
443
/// The progress of a [`ResultBuilder`].
enum ResultState {
    /// No branch result has been added yet. If the builder is finished in this state, the
    /// final result contains only null values.
    Empty,
    /// The final result needs to be computed by merging the data in `arrays`.
    Partial {
        /// A `Vec` of partial results that should be merged.
        /// `indices` contains indexes into this vec.
        arrays: Vec<ArrayRef>,
        /// Indicates per result row from which array in `arrays` a value should be taken.
        indices: Vec<PartialResultIndex>,
    },
    /// A single branch matched all input rows. When creating the final result, no further merging
    /// of partial results is necessary.
    Complete(ColumnarValue),
}
459
/// A builder for constructing result arrays for CASE expressions.
///
/// Rather than building a monolithic array containing all results, it maintains a set of
/// partial result arrays and a mapping that indicates for each row which partial array
/// contains the result value for that row.
///
/// On finish(), the builder will merge all partial results into a single array if necessary.
/// If all rows evaluated to the same array, that array can be returned directly without
/// any merging overhead.
struct ResultBuilder {
    /// The data type of the result; used to create a typed null when no branch produced
    /// a value.
    data_type: DataType,
    /// The number of rows in the final result.
    row_count: usize,
    /// The results collected so far.
    state: ResultState,
}
475
476impl ResultBuilder {
477    /// Creates a new ResultBuilder that will produce arrays of the given data type.
478    ///
479    /// The `row_count` parameter indicates the number of rows in the final result.
480    fn new(data_type: &DataType, row_count: usize) -> Self {
481        Self {
482            data_type: data_type.clone(),
483            row_count,
484            state: ResultState::Empty,
485        }
486    }
487
488    /// Adds a result for one branch of the case expression.
489    ///
490    /// `row_indices` should be a [UInt32Array] containing [RecordBatch] relative row indices
491    /// for which `value` contains result values.
492    ///
493    /// If `value` is a scalar, the scalar value will be used as the value for each row in `row_indices`.
494    ///
495    /// If `value` is an array, the values from the array and the indices from `row_indices` will be
496    /// processed pairwise. The lengths of `value` and `row_indices` must match.
497    ///
498    /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the
499    /// record batch. The then expression produced the value array `[A, D]`.
500    /// After adding this result, the result array will have been added to `partial arrays` and
501    /// `partial indices` will have been updated at indexes `1` and `4`.
502    ///
503    /// ```text
504    ///  ┌─────────┐     ┌─────────┐┌───────────┐                            ┌─────────┐┌───────────┐
505    ///  │    C    │     │ 0: None ││┌ 0 ──────┐│                            │ 0: None ││┌ 0 ──────┐│
506    ///  ├─────────┤     ├─────────┤││    A    ││                            ├─────────┤││    A    ││
507    ///  │    D    │     │ 1: None ││└─────────┘│                            │ 1:  2   ││└─────────┘│
508    ///  └─────────┘     ├─────────┤│┌ 1 ──────┐│   add_branch_result(       ├─────────┤│┌ 1 ──────┐│
509    ///   matching       │ 2:  0   │││    B    ││     row indices,           │ 2:  0   │││    B    ││
510    /// 'then' values    ├─────────┤│└─────────┘│     value                  ├─────────┤│└─────────┘│
511    ///                  │ 3: None ││           │   )                        │ 3: None ││┌ 2 ──────┐│
512    ///  ┌─────────┐     ├─────────┤│           │ ─────────────────────────▶ ├─────────┤││    C    ││
513    ///  │    1    │     │ 4: None ││           │                            │ 4:  2   ││├─────────┤│
514    ///  ├─────────┤     ├─────────┤│           │                            ├─────────┤││    D    ││
515    ///  │    4    │     │ 5:  1   ││           │                            │ 5:  1   ││└─────────┘│
516    ///  └─────────┘     └─────────┘└───────────┘                            └─────────┘└───────────┘
517    /// row indices        partial     partial                                 partial     partial
518    ///                    indices     arrays                                  indices     arrays
519    /// ```
520    fn add_branch_result(
521        &mut self,
522        row_indices: &ArrayRef,
523        value: ColumnarValue,
524    ) -> Result<()> {
525        match value {
526            ColumnarValue::Array(a) => {
527                if a.len() != row_indices.len() {
528                    internal_err!("Array length must match row indices length")
529                } else if row_indices.len() == self.row_count {
530                    self.set_complete_result(ColumnarValue::Array(a))
531                } else {
532                    self.add_partial_result(row_indices, a)
533                }
534            }
535            ColumnarValue::Scalar(s) => {
536                if row_indices.len() == self.row_count {
537                    self.set_complete_result(ColumnarValue::Scalar(s))
538                } else {
539                    self.add_partial_result(
540                        row_indices,
541                        s.to_array_of_size(row_indices.len())?,
542                    )
543                }
544            }
545        }
546    }
547
548    /// Adds a partial result array.
549    ///
550    /// This method adds the given array data as a partial result and updates the index mapping
551    /// to indicate that the specified rows should take their values from this array.
552    /// The partial results will be merged into a single array when finish() is called.
553    fn add_partial_result(
554        &mut self,
555        row_indices: &ArrayRef,
556        row_values: ArrayRef,
557    ) -> Result<()> {
558        assert_or_internal_err!(
559            row_indices.null_count() == 0,
560            "Row indices must not contain nulls"
561        );
562
563        match &mut self.state {
564            ResultState::Empty => {
565                let array_index = PartialResultIndex::zero();
566                let mut indices = vec![PartialResultIndex::none(); self.row_count];
567                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
568                    indices[*row_ix as usize] = array_index;
569                }
570
571                self.state = ResultState::Partial {
572                    arrays: vec![row_values],
573                    indices,
574                };
575
576                Ok(())
577            }
578            ResultState::Partial { arrays, indices } => {
579                let array_index = PartialResultIndex::try_new(arrays.len())?;
580
581                arrays.push(row_values);
582
583                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
584                    // This is check is only active for debug config because the callers of this method,
585                    // `case_when_with_expr` and `case_when_no_expr`, already ensure that
586                    // they only calculate a value for each row at most once.
587                    #[cfg(debug_assertions)]
588                    assert_or_internal_err!(
589                        indices[*row_ix as usize].is_none(),
590                        "Duplicate value for row {}",
591                        *row_ix
592                    );
593
594                    indices[*row_ix as usize] = array_index;
595                }
596                Ok(())
597            }
598            ResultState::Complete(_) => internal_err!(
599                "Cannot add a partial result when complete result is already set"
600            ),
601        }
602    }
603
604    /// Sets a result that applies to all rows.
605    ///
606    /// This is an optimization for cases where all rows evaluate to the same result.
607    /// When a complete result is set, the builder will return it directly from finish()
608    /// without any merging overhead.
609    fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
610        match &self.state {
611            ResultState::Empty => {
612                self.state = ResultState::Complete(value);
613                Ok(())
614            }
615            ResultState::Partial { .. } => {
616                internal_err!(
617                    "Cannot set a complete result when there are already partial results"
618                )
619            }
620            ResultState::Complete(_) => internal_err!("Complete result already set"),
621        }
622    }
623
624    /// Finishes building the result and returns the final array.
625    fn finish(self) -> Result<ColumnarValue> {
626        match self.state {
627            ResultState::Empty => {
628                // No complete result and no partial results.
629                // This can happen for case expressions with no else branch where no rows
630                // matched.
631                Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
632                    &self.data_type,
633                )?))
634            }
635            ResultState::Partial { arrays, indices } => {
636                // Merge partial results into a single array.
637                let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
638                Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
639            }
640            ResultState::Complete(v) => {
641                // If we have a complete result, we can just return it.
642                Ok(v)
643            }
644        }
645    }
646}
647
648impl CaseExpr {
649    /// Create a new CASE WHEN expression
650    pub fn try_new(
651        expr: Option<Arc<dyn PhysicalExpr>>,
652        when_then_expr: Vec<WhenThen>,
653        else_expr: Option<Arc<dyn PhysicalExpr>>,
654    ) -> Result<Self> {
655        // normalize null literals to None in the else_expr (this already happens
656        // during SQL planning, but not necessarily for other use cases)
657        let else_expr = match &else_expr {
658            Some(e) => match e.downcast_ref::<Literal>() {
659                Some(lit) if lit.value().is_null() => None,
660                _ => else_expr,
661            },
662            _ => else_expr,
663        };
664
665        if when_then_expr.is_empty() {
666            return exec_err!("There must be at least one WHEN clause");
667        }
668
669        let body = CaseBody {
670            expr,
671            when_then_expr,
672            else_expr,
673        };
674
675        let eval_method = Self::find_best_eval_method(&body)?;
676
677        Ok(Self { body, eval_method })
678    }
679
680    fn find_best_eval_method(body: &CaseBody) -> Result<EvalMethod> {
681        if body.expr.is_some() {
682            if let Some(mapping) = LiteralLookupTable::maybe_new(body) {
683                return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
684            }
685
686            return Ok(EvalMethod::WithExpression(body.project()?));
687        }
688
689        Ok(
690            if body.when_then_expr.len() == 1
691                && is_cheap_and_infallible(&(body.when_then_expr[0].1))
692                && body.else_expr.is_none()
693            {
694                EvalMethod::InfallibleExprOrNull
695            } else if body.when_then_expr.len() == 1
696                && body.when_then_expr[0].1.is::<Literal>()
697                && body.else_expr.is_some()
698                && body.else_expr.as_ref().unwrap().is::<Literal>()
699            {
700                EvalMethod::ScalarOrScalar
701            } else if body.when_then_expr.len() == 1 {
702                EvalMethod::ExpressionOrExpression(body.project()?)
703            } else {
704                EvalMethod::NoExpression(body.project()?)
705            },
706        )
707    }
708
709    /// Optional base expression that can be compared to literal values in the "when" expressions
710    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
711        self.body.expr.as_ref()
712    }
713
714    /// One or more when/then expressions
715    pub fn when_then_expr(&self) -> &[WhenThen] {
716        &self.body.when_then_expr
717    }
718
719    /// Optional "else" expression
720    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
721        self.body.else_expr.as_ref()
722    }
723}
724
725impl CaseBody {
726    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
727        // since all then results have the same data type, we can choose any one as the
728        // return data type except for the null.
729        let mut data_type = DataType::Null;
730        for i in 0..self.when_then_expr.len() {
731            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
732            if !data_type.equals_datatype(&DataType::Null) {
733                break;
734            }
735        }
736        // if all then results are null, we use data type of else expr instead if possible.
737        if data_type.equals_datatype(&DataType::Null)
738            && let Some(e) = &self.else_expr
739        {
740            data_type = e.data_type(input_schema)?;
741        }
742
743        Ok(data_type)
744    }
745
746    /// See [CaseExpr::case_when_with_expr].
747    fn case_when_with_expr(
748        &self,
749        batch: &RecordBatch,
750        return_type: &DataType,
751    ) -> Result<ColumnarValue> {
752        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
753
754        // `remainder_rows` contains the indices of the rows that need to be evaluated
755        let mut remainder_rows: ArrayRef =
756            Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
757        // `remainder_batch` contains the rows themselves that need to be evaluated
758        let mut remainder_batch = Cow::Borrowed(batch);
759
760        // evaluate the base expression
761        let mut base_values = self
762            .expr
763            .as_ref()
764            .unwrap()
765            .evaluate(batch)?
766            .into_array(batch.num_rows())?;
767
768        // Fill in a result value already for rows where the base expression value is null
769        // Since each when expression is tested against the base expression using the equality
770        // operator, null base values can never match any when expression. `x = NULL` is falsy,
771        // for all possible values of `x`.
772        let base_null_count = base_values.logical_null_count();
773        if base_null_count > 0 {
774            // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
775            // We already checked there are nulls, so we can be sure a new buffer will not be
776            // created.
777            let base_not_nulls = is_not_null(base_values.as_ref())?;
778            let base_all_null = base_null_count == remainder_batch.num_rows();
779
780            // If there is an else expression, use that as the default value for the null rows
781            // Otherwise the default `null` value from the result builder will be used.
782            if let Some(e) = &self.else_expr {
783                let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
784
785                if base_all_null {
786                    // All base values were null, so no need to filter
787                    let nulls_value = expr.evaluate(&remainder_batch)?;
788                    result_builder.add_branch_result(&remainder_rows, nulls_value)?;
789                } else {
790                    // Filter out the null rows and evaluate the else expression for those
791                    let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
792                    let nulls_batch =
793                        filter_record_batch(&remainder_batch, &nulls_filter)?;
794                    let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
795                    let nulls_value = expr.evaluate(&nulls_batch)?;
796                    result_builder.add_branch_result(&nulls_rows, nulls_value)?;
797                }
798            }
799
800            // All base values are null, so we can return early
801            if base_all_null {
802                return result_builder.finish();
803            }
804
805            // Remove the null rows from the remainder batch
806            let not_null_filter = create_filter(&base_not_nulls, true);
807            remainder_batch =
808                Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
809            remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
810            base_values = filter_array(&base_values, &not_null_filter)?;
811        }
812
813        // The types of case and when expressions will be coerced to match.
814        // We only need to check if the base_value is nested.
815        let base_value_is_nested = base_values.data_type().is_nested();
816
817        for i in 0..self.when_then_expr.len() {
818            // Evaluate the 'when' predicate for the remainder batch
819            // This results in a boolean array with the same length as the remaining number of rows
820            let when_expr = &self.when_then_expr[i].0;
821            let when_value = match when_expr.evaluate(&remainder_batch)? {
822                ColumnarValue::Array(a) => {
823                    compare_with_eq(&a, &base_values, base_value_is_nested)
824                }
825                ColumnarValue::Scalar(s) => {
826                    compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
827                }
828            }?;
829
830            // If the 'when' predicate did not match any rows, continue to the next branch immediately.
831            // Only counts valid slots that are true (masked-null predicate slots are ignored),
832            // so no `prep_null_mask_filter` needed here.
833            if !when_value.has_true() {
834                continue;
835            }
836
837            // If the 'when' predicate matched all remaining rows, there is no need to filter
838            if when_value.null_count() == 0 && !when_value.has_false() {
839                let then_expression = &self.when_then_expr[i].1;
840                let then_value = then_expression.evaluate(&remainder_batch)?;
841                result_builder.add_branch_result(&remainder_rows, then_value)?;
842                return result_builder.finish();
843            }
844
845            // Filter the remainder batch based on the 'when' value
846            // This results in a batch containing only the rows that need to be evaluated
847            // for the current branch
848            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
849            // this unconditionally.
850            let then_filter = create_filter(&when_value, true);
851            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
852            let then_rows = filter_array(&remainder_rows, &then_filter)?;
853
854            let then_expression = &self.when_then_expr[i].1;
855            let then_value = then_expression.evaluate(&then_batch)?;
856            result_builder.add_branch_result(&then_rows, then_value)?;
857
858            // If this is the last 'when' branch and there is no 'else' expression, there's no
859            // point in calculating the remaining rows.
860            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
861                return result_builder.finish();
862            }
863
864            // Prepare the next when branch (or the else branch)
865            let next_selection = match when_value.null_count() {
866                0 => not(&when_value),
867                _ => {
868                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
869                    // as false
870                    not(&prep_null_mask_filter(&when_value))
871                }
872            }?;
873            let next_filter = create_filter(&next_selection, true);
874            remainder_batch =
875                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
876            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
877            base_values = filter_array(&base_values, &next_filter)?;
878        }
879
880        // If we reached this point, some rows were left unmatched.
881        // Check if those need to be evaluated using the 'else' expression.
882        if let Some(e) = &self.else_expr {
883            // keep `else_expr`'s data type and return type consistent
884            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
885            let else_value = expr.evaluate(&remainder_batch)?;
886            result_builder.add_branch_result(&remainder_rows, else_value)?;
887        }
888
889        result_builder.finish()
890    }
891
892    /// See [CaseExpr::case_when_no_expr].
893    fn case_when_no_expr(
894        &self,
895        batch: &RecordBatch,
896        return_type: &DataType,
897    ) -> Result<ColumnarValue> {
898        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
899
900        // `remainder_rows` contains the indices of the rows that need to be evaluated
901        let mut remainder_rows: ArrayRef =
902            Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
903        // `remainder_batch` contains the rows themselves that need to be evaluated
904        let mut remainder_batch = Cow::Borrowed(batch);
905
906        for i in 0..self.when_then_expr.len() {
907            // Evaluate the 'when' predicate for the remainder batch
908            // This results in a boolean array with the same length as the remaining number of rows
909            let when_predicate = &self.when_then_expr[i].0;
910            let when_value = when_predicate
911                .evaluate(&remainder_batch)?
912                .into_array(remainder_batch.num_rows())?;
913            let when_value = as_boolean_array(&when_value).map_err(|_| {
914                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
915            })?;
916
917            // If the 'when' predicate did not match any rows, continue to the next branch immediately.
918            // Only counts valid slots that are true (masked-null predicate slots are ignored)
919            // so no `prep_null_mask_filter` needed here.
920            if !when_value.has_true() {
921                continue;
922            }
923
924            // If the 'when' predicate matched all remaining rows, there is no need to filter
925            if when_value.null_count() == 0 && !when_value.has_false() {
926                let then_expression = &self.when_then_expr[i].1;
927                let then_value = then_expression.evaluate(&remainder_batch)?;
928                result_builder.add_branch_result(&remainder_rows, then_value)?;
929                return result_builder.finish();
930            }
931
932            // Filter the remainder batch based on the 'when' value
933            // This results in a batch containing only the rows that need to be evaluated
934            // for the current branch
935            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
936            // this unconditionally.
937            let then_filter = create_filter(when_value, true);
938            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
939            let then_rows = filter_array(&remainder_rows, &then_filter)?;
940
941            let then_expression = &self.when_then_expr[i].1;
942            let then_value = then_expression.evaluate(&then_batch)?;
943            result_builder.add_branch_result(&then_rows, then_value)?;
944
945            // If this is the last 'when' branch and there is no 'else' expression, there's no
946            // point in calculating the remaining rows.
947            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
948                return result_builder.finish();
949            }
950
951            // Prepare the next when branch (or the else branch)
952            let next_selection = match when_value.null_count() {
953                0 => not(when_value),
954                _ => {
955                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
956                    // as false
957                    not(&prep_null_mask_filter(when_value))
958                }
959            }?;
960            let next_filter = create_filter(&next_selection, true);
961            remainder_batch =
962                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
963            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
964        }
965
966        // If we reached this point, some rows were left unmatched.
967        // Check if those need to be evaluated using the 'else' expression.
968        if let Some(e) = &self.else_expr {
969            // keep `else_expr`'s data type and return type consistent
970            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
971            let else_value = expr.evaluate(&remainder_batch)?;
972            result_builder.add_branch_result(&remainder_rows, else_value)?;
973        }
974
975        result_builder.finish()
976    }
977
978    /// See [CaseExpr::expr_or_expr].
979    fn expr_or_expr(
980        &self,
981        batch: &RecordBatch,
982        when_value: &BooleanArray,
983    ) -> Result<ColumnarValue> {
984        let when_value = match when_value.null_count() {
985            0 => Cow::Borrowed(when_value),
986            _ => {
987                // `prep_null_mask_filter` is required to ensure null is treated as false
988                Cow::Owned(prep_null_mask_filter(when_value))
989            }
990        };
991
992        let optimize_filter = batch.num_columns() > 1
993            || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
994
995        let when_filter = create_filter(&when_value, optimize_filter);
996        let then_batch = filter_record_batch(batch, &when_filter)?;
997        let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
998
999        match &self.else_expr {
1000            None => {
1001                let then_array = then_value.to_array(when_value.true_count())?;
1002                scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array)
1003            }
1004            Some(else_expr) => {
1005                let else_selection = not(&when_value)?;
1006                let else_filter = create_filter(&else_selection, optimize_filter);
1007                let else_batch = filter_record_batch(batch, &else_filter)?;
1008
1009                // keep `else_expr`'s data type and return type consistent
1010                let return_type = self.data_type(&batch.schema())?;
1011                let else_expr =
1012                    try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone())
1013                        .unwrap_or_else(|_| Arc::clone(else_expr));
1014
1015                let else_value = else_expr.evaluate(&else_batch)?;
1016
1017                Ok(ColumnarValue::Array(match (then_value, else_value) {
1018                    (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
1019                        merge(&when_value, &t, &e)
1020                    }
1021                    (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
1022                        merge(&when_value, &t.to_scalar()?, &e)
1023                    }
1024                    (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
1025                        merge(&when_value, &t, &e.to_scalar()?)
1026                    }
1027                    (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
1028                        merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
1029                    }
1030                }?))
1031            }
1032        }
1033    }
1034}
1035
1036impl CaseExpr {
1037    /// This function evaluates the form of CASE that matches an expression to fixed values.
1038    ///
1039    /// CASE expression
1040    ///     WHEN value THEN result
1041    ///     [WHEN ...]
1042    ///     [ELSE result]
1043    /// END
1044    fn case_when_with_expr(
1045        &self,
1046        batch: &RecordBatch,
1047        projected: &ProjectedCaseBody,
1048    ) -> Result<ColumnarValue> {
1049        let return_type = self.data_type(&batch.schema())?;
1050        // projected.projection may include indexes of lambda variables not available on this batch
1051        let projection = projected
1052            .projection
1053            .iter()
1054            .copied()
1055            .filter(|index| *index < batch.num_columns())
1056            .collect::<Vec<_>>();
1057        if projection.len() < batch.num_columns() {
1058            let projected_batch = batch.project(&projection)?;
1059            projected
1060                .body
1061                .case_when_with_expr(&projected_batch, &return_type)
1062        } else {
1063            self.body.case_when_with_expr(batch, &return_type)
1064        }
1065    }
1066
1067    /// This function evaluates the form of CASE where each WHEN expression is a boolean
1068    /// expression.
1069    ///
1070    /// CASE WHEN condition THEN result
1071    ///      [WHEN ...]
1072    ///      [ELSE result]
1073    /// END
1074    fn case_when_no_expr(
1075        &self,
1076        batch: &RecordBatch,
1077        projected: &ProjectedCaseBody,
1078    ) -> Result<ColumnarValue> {
1079        let return_type = self.data_type(&batch.schema())?;
1080        // projected.projection may include indexes of lambda variables not available on this batch
1081        let projection = projected
1082            .projection
1083            .iter()
1084            .copied()
1085            .filter(|index| *index < batch.num_columns())
1086            .collect::<Vec<_>>();
1087        if projection.len() < batch.num_columns() {
1088            let projected_batch = batch.project(&projection)?;
1089            projected
1090                .body
1091                .case_when_no_expr(&projected_batch, &return_type)
1092        } else {
1093            self.body.case_when_no_expr(batch, &return_type)
1094        }
1095    }
1096
1097    /// This function evaluates the specialized case of:
1098    ///
1099    /// CASE WHEN condition THEN column
1100    ///      [ELSE NULL]
1101    /// END
1102    ///
1103    /// Note that this function is only safe to use for "then" expressions
1104    /// that are infallible because the expression will be evaluated for all
1105    /// rows in the input batch.
1106    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1107        let when_expr = &self.body.when_then_expr[0].0;
1108        let then_expr = &self.body.when_then_expr[0].1;
1109
1110        match when_expr.evaluate(batch)? {
1111            // WHEN true --> column
1112            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1113                then_expr.evaluate(batch)
1114            }
1115            // WHEN [false | null] --> NULL
1116            ColumnarValue::Scalar(_) => {
1117                // return scalar NULL value
1118                ScalarValue::try_from(self.data_type(&batch.schema())?)
1119                    .map(ColumnarValue::Scalar)
1120            }
1121            // WHEN column --> column
1122            ColumnarValue::Array(bit_mask) => {
1123                let bit_mask = bit_mask
1124                    .as_any()
1125                    .downcast_ref::<BooleanArray>()
1126                    .expect("predicate should evaluate to a boolean array");
1127                // invert the bitmask
1128                let bit_mask = match bit_mask.null_count() {
1129                    0 => not(bit_mask)?,
1130                    _ => not(&prep_null_mask_filter(bit_mask))?,
1131                };
1132                match then_expr.evaluate(batch)? {
1133                    ColumnarValue::Array(array) => {
1134                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1135                    }
1136                    ColumnarValue::Scalar(_) => {
1137                        internal_err!("expression did not evaluate to an array")
1138                    }
1139                }
1140            }
1141        }
1142    }
1143
1144    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1145        let return_type = self.data_type(&batch.schema())?;
1146
1147        // evaluate when expression
1148        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1149        let when_value = when_value.into_array(batch.num_rows())?;
1150        let when_value = as_boolean_array(&when_value).map_err(|_| {
1151            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1152        })?;
1153
1154        // Treat 'NULL' as false value
1155        let when_value = match when_value.null_count() {
1156            0 => Cow::Borrowed(when_value),
1157            _ => Cow::Owned(prep_null_mask_filter(when_value)),
1158        };
1159
1160        // evaluate then_value
1161        let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1162        let then_value = Scalar::new(then_value.into_array(1)?);
1163
1164        let Some(e) = &self.body.else_expr else {
1165            return internal_err!("expression did not evaluate to an array");
1166        };
1167        // keep `else_expr`'s data type and return type consistent
1168        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1169        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1170        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1171    }
1172
1173    fn expr_or_expr(
1174        &self,
1175        batch: &RecordBatch,
1176        projected: &ProjectedCaseBody,
1177    ) -> Result<ColumnarValue> {
1178        // evaluate when condition on batch
1179        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1180        // `num_rows == 1` is intentional to avoid expanding scalars.
1181        // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1182        // below will avoid incorrectly using the scalar as a merge/zip mask.
1183        let when_value = when_value.into_array(1)?;
1184        let when_value = as_boolean_array(&when_value).map_err(|e| {
1185            DataFusionError::Context(
1186                "WHEN expression did not return a BooleanArray".to_string(),
1187                Box::new(e),
1188            )
1189        })?;
1190
1191        if when_value.null_count() == 0 && !when_value.has_false() {
1192            // All input rows are true, just call the 'then' expression
1193            self.body.when_then_expr[0].1.evaluate(batch)
1194        } else if !when_value.has_true() {
1195            // All input rows are false/null, just call the 'else' expression
1196            match &self.body.else_expr {
1197                Some(else_expr) => else_expr.evaluate(batch),
1198                None => {
1199                    let return_type = self.data_type(&batch.schema())?;
1200                    Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
1201                        &return_type,
1202                    )?))
1203                }
1204            }
1205        } else {
1206            // projected.projection may include indexes of lambda variables not available on this batch
1207            let projection = projected
1208                .projection
1209                .iter()
1210                .copied()
1211                .filter(|index| *index < batch.num_columns())
1212                .collect::<Vec<_>>();
1213            if projection.len() < batch.num_columns() {
1214                // The case expressions do not use all the columns of the input batch.
1215                // Project first to reduce time spent filtering.
1216                let projected_batch = batch.project(&projection)?;
1217                projected.body.expr_or_expr(&projected_batch, when_value)
1218            } else {
1219                // All columns are used in the case expressions, so there is no need to project.
1220                self.body.expr_or_expr(batch, when_value)
1221            }
1222        }
1223    }
1224
1225    fn with_lookup_table(
1226        &self,
1227        batch: &RecordBatch,
1228        lookup_table: &LiteralLookupTable,
1229    ) -> Result<ColumnarValue> {
1230        let expr = self.body.expr.as_ref().unwrap();
1231        let evaluated_expression = expr.evaluate(batch)?;
1232
1233        let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_));
1234        let evaluated_expression = evaluated_expression.to_array(1)?;
1235
1236        let values = lookup_table.map_keys_to_values(&evaluated_expression)?;
1237
1238        let result = if is_scalar {
1239            ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?)
1240        } else {
1241            ColumnarValue::Array(values)
1242        };
1243
1244        Ok(result)
1245    }
1246}
1247
1248impl PhysicalExpr for CaseExpr {
    /// Returns the data type of this expression for `input_schema` by
    /// delegating to the body's `data_type`.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.body.data_type(input_schema)
    }
1252
1253    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1254        let nullable_then = self
1255            .body
1256            .when_then_expr
1257            .iter()
1258            .filter_map(|(w, t)| {
1259                let is_nullable = match t.nullable(input_schema) {
1260                    // Pass on error determining nullability verbatim
1261                    Err(e) => return Some(Err(e)),
1262                    Ok(n) => n,
1263                };
1264
1265                // Branches with a then expression that is not nullable do not impact the
1266                // nullability of the case expression.
1267                if !is_nullable {
1268                    return None;
1269                }
1270
1271                // For case-with-expression assume all 'then' expressions are reachable
1272                if self.body.expr.is_some() {
1273                    return Some(Ok(()));
1274                }
1275
1276                // For branches with a nullable 'then' expression, try to determine
1277                // if the 'then' expression is ever reachable in the situation where
1278                // it would evaluate to null.
1279
1280                // Replace the `then` expression with `NULL` in the `when` expression
1281                let with_null = match replace_with_null(w, t.as_ref(), input_schema) {
1282                    Err(e) => return Some(Err(e)),
1283                    Ok(e) => e,
1284                };
1285
1286                // Try to const evaluate the modified `when` expression.
1287                let predicate_result = match evaluate_predicate(&with_null) {
1288                    Err(e) => return Some(Err(e)),
1289                    Ok(b) => b,
1290                };
1291
1292                match predicate_result {
1293                    // Evaluation was inconclusive or true, so the 'then' expression is reachable
1294                    None | Some(true) => Some(Ok(())),
1295                    // Evaluation proves the branch will never be taken.
1296                    // The most common pattern for this is `WHEN x IS NOT NULL THEN x`.
1297                    Some(false) => None,
1298                }
1299            })
1300            .next();
1301
1302        if let Some(nullable_then) = nullable_then {
1303            // There is at least one reachable nullable 'then' expression, so the case
1304            // expression itself is nullable.
1305            // Use `Result::map` to propagate the error from `nullable_then` if there is one.
1306            nullable_then.map(|_| true)
1307        } else if let Some(e) = &self.body.else_expr {
1308            // There are no reachable nullable 'then' expressions, so all we still need to
1309            // check is the 'else' expression's nullability.
1310            e.nullable(input_schema)
1311        } else {
1312            // CASE produces NULL if there is no `else` expr
1313            // (aka when none of the `when_then_exprs` match)
1314            Ok(true)
1315        }
1316    }
1317
1318    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1319        match &self.eval_method {
1320            EvalMethod::WithExpression(p) => {
1321                // this use case evaluates "expr" and then compares the values with the "when"
1322                // values
1323                self.case_when_with_expr(batch, p)
1324            }
1325            EvalMethod::NoExpression(p) => {
1326                // The "when" conditions all evaluate to boolean in this use case and can be
1327                // arbitrary expressions
1328                self.case_when_no_expr(batch, p)
1329            }
1330            EvalMethod::InfallibleExprOrNull => {
1331                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
1332                self.case_column_or_null(batch)
1333            }
1334            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1335            EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1336            EvalMethod::WithExprScalarLookupTable(lookup_table) => {
1337                self.with_lookup_table(batch, lookup_table)
1338            }
1339        }
1340    }
1341
1342    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1343        let mut children = vec![];
1344        if let Some(expr) = &self.body.expr {
1345            children.push(expr)
1346        }
1347        self.body.when_then_expr.iter().for_each(|(cond, value)| {
1348            children.push(cond);
1349            children.push(value);
1350        });
1351
1352        if let Some(else_expr) = &self.body.else_expr {
1353            children.push(else_expr)
1354        }
1355        children
1356    }
1357
1358    // For physical CaseExpr, we do not allow modifying children size
1359    fn with_new_children(
1360        self: Arc<Self>,
1361        children: Vec<Arc<dyn PhysicalExpr>>,
1362    ) -> Result<Arc<dyn PhysicalExpr>> {
1363        if children.len() != self.children().len() {
1364            internal_err!("CaseExpr: Wrong number of children")
1365        } else {
1366            let (expr, when_then_expr, else_expr) =
1367                match (self.expr().is_some(), self.body.else_expr.is_some()) {
1368                    (true, true) => (
1369                        Some(&children[0]),
1370                        &children[1..children.len() - 1],
1371                        Some(&children[children.len() - 1]),
1372                    ),
1373                    (true, false) => {
1374                        (Some(&children[0]), &children[1..children.len()], None)
1375                    }
1376                    (false, true) => (
1377                        None,
1378                        &children[0..children.len() - 1],
1379                        Some(&children[children.len() - 1]),
1380                    ),
1381                    (false, false) => (None, &children[0..children.len()], None),
1382                };
1383            Ok(Arc::new(CaseExpr::try_new(
1384                expr.cloned(),
1385                when_then_expr.iter().cloned().tuples().collect(),
1386                else_expr.cloned(),
1387            )?))
1388        }
1389    }
1390
1391    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1392        write!(f, "CASE ")?;
1393        if let Some(e) = &self.body.expr {
1394            e.fmt_sql(f)?;
1395            write!(f, " ")?;
1396        }
1397
1398        for (w, t) in &self.body.when_then_expr {
1399            write!(f, "WHEN ")?;
1400            w.fmt_sql(f)?;
1401            write!(f, " THEN ")?;
1402            t.fmt_sql(f)?;
1403            write!(f, " ")?;
1404        }
1405
1406        if let Some(e) = &self.body.else_expr {
1407            write!(f, "ELSE ")?;
1408            e.fmt_sql(f)?;
1409            write!(f, " ")?;
1410        }
1411        write!(f, "END")
1412    }
1413}
1414
1415/// Attempts to const evaluate the given `predicate`.
1416/// Returns:
1417/// - `Some(true)` if the predicate evaluates to a truthy value.
1418/// - `Some(false)` if the predicate evaluates to a falsy value.
1419/// - `None` if the predicate could not be evaluated.
1420fn evaluate_predicate(predicate: &Arc<dyn PhysicalExpr>) -> Result<Option<bool>> {
1421    // Create a dummy record with no columns and one row
1422    let batch = RecordBatch::try_new_with_options(
1423        Arc::new(Schema::empty()),
1424        vec![],
1425        &RecordBatchOptions::new().with_row_count(Some(1)),
1426    )?;
1427
1428    // Evaluate the predicate and interpret the result as a boolean
1429    let result = match predicate.evaluate(&batch) {
1430        // An error during evaluation means we couldn't const evaluate the predicate, so return `None`
1431        Err(_) => None,
1432        Ok(ColumnarValue::Array(array)) => Some(
1433            ScalarValue::try_from_array(array.as_ref(), 0)?
1434                .cast_to(&DataType::Boolean)?,
1435        ),
1436        Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?),
1437    };
1438    Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true)))))
1439}
1440
1441fn replace_with_null(
1442    expr: &Arc<dyn PhysicalExpr>,
1443    expr_to_replace: &dyn PhysicalExpr,
1444    input_schema: &Schema,
1445) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
1446    let with_null = Arc::clone(expr)
1447        .transform_down(|e| {
1448            if e.as_ref().dyn_eq(expr_to_replace) {
1449                let data_type = e.data_type(input_schema)?;
1450                let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
1451                Ok(Transformed::yes(null_literal))
1452            } else {
1453                Ok(Transformed::no(e))
1454            }
1455        })?
1456        .data;
1457    Ok(with_null)
1458}
1459
1460/// Create a CASE expression
1461pub fn case(
1462    expr: Option<Arc<dyn PhysicalExpr>>,
1463    when_thens: Vec<WhenThen>,
1464    else_expr: Option<Arc<dyn PhysicalExpr>>,
1465) -> Result<Arc<dyn PhysicalExpr>> {
1466    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1467}
1468
1469#[cfg(test)]
1470mod tests {
1471    use super::*;
1472
1473    use crate::expressions;
1474    use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null};
1475    use arrow::buffer::Buffer;
1476    use arrow::datatypes::DataType::Float64;
1477    use arrow::datatypes::Field;
1478    use datafusion_common::cast::{as_float64_array, as_int32_array};
1479    use datafusion_common::plan_err;
1480    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1481    use datafusion_expr::type_coercion::binary::type_union_coercion;
1482    use datafusion_expr_common::operator::Operator;
1483    use datafusion_physical_expr_common::physical_expr::fmt_sql;
1484    use half::f16;
1485
1486    #[test]
1487    fn case_with_expr() -> Result<()> {
1488        let batch = case_test_batch()?;
1489        let schema = batch.schema();
1490
1491        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1492        let when1 = lit("foo");
1493        let then1 = lit(123i32);
1494        let when2 = lit("bar");
1495        let then2 = lit(456i32);
1496
1497        let expr = generate_case_when_with_type_coercion(
1498            Some(col("a", &schema)?),
1499            vec![(when1, then1), (when2, then2)],
1500            None,
1501            schema.as_ref(),
1502        )?;
1503        let result = expr
1504            .evaluate(&batch)?
1505            .into_array(batch.num_rows())
1506            .expect("Failed to convert to array");
1507        let result = as_int32_array(&result)?;
1508
1509        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1510
1511        assert_eq!(expected, result);
1512
1513        Ok(())
1514    }
1515
1516    #[test]
1517    fn case_with_expr_dictionary() -> Result<()> {
1518        let schema = Schema::new(vec![Field::new(
1519            "a",
1520            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1521            true,
1522        )]);
1523        let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1524        let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1525        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1526        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1527
1528        let schema = batch.schema();
1529
1530        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1531        let when1 = lit("foo");
1532        let then1 = lit(123i32);
1533        let when2 = lit("bar");
1534        let then2 = lit(456i32);
1535
1536        let expr = generate_case_when_with_type_coercion(
1537            Some(col("a", &schema)?),
1538            vec![(when1, then1), (when2, then2)],
1539            None,
1540            schema.as_ref(),
1541        )?;
1542        let result = expr
1543            .evaluate(&batch)?
1544            .into_array(batch.num_rows())
1545            .expect("Failed to convert to array");
1546        let result = as_int32_array(&result)?;
1547
1548        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1549
1550        assert_eq!(expected, result);
1551
1552        Ok(())
1553    }
1554
1555    // Make sure we are not failing when got literal in case when but input is dictionary encoded
1556    #[test]
1557    fn case_with_expr_primitive_dictionary() -> Result<()> {
1558        let schema = Schema::new(vec![Field::new(
1559            "a",
1560            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64)),
1561            true,
1562        )]);
1563        let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1564        let values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]);
1565        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1566        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1567
1568        let schema = batch.schema();
1569
1570        // CASE a WHEN 10 THEN 123 WHEN 30 THEN 456 END
1571        let when1 = lit(10_u64);
1572        let then1 = lit(123_i32);
1573        let when2 = lit(30_u64);
1574        let then2 = lit(456_i32);
1575
1576        let expr = generate_case_when_with_type_coercion(
1577            Some(col("a", &schema)?),
1578            vec![(when1, then1), (when2, then2)],
1579            None,
1580            schema.as_ref(),
1581        )?;
1582        let result = expr
1583            .evaluate(&batch)?
1584            .into_array(batch.num_rows())
1585            .expect("Failed to convert to array");
1586        let result = as_int32_array(&result)?;
1587
1588        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1589
1590        assert_eq!(expected, result);
1591
1592        Ok(())
1593    }
1594
1595    // Make sure we are not failing when got literal in case when but input is dictionary encoded
1596    #[test]
1597    fn case_with_expr_boolean_dictionary() -> Result<()> {
1598        let schema = Schema::new(vec![Field::new(
1599            "a",
1600            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean)),
1601            true,
1602        )]);
1603        let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1604        let values = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
1605        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1606        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1607
1608        let schema = batch.schema();
1609
1610        // CASE a WHEN true THEN 123 WHEN false THEN 456 END
1611        let when1 = lit(true);
1612        let then1 = lit(123i32);
1613        let when2 = lit(false);
1614        let then2 = lit(456i32);
1615
1616        let expr = generate_case_when_with_type_coercion(
1617            Some(col("a", &schema)?),
1618            vec![(when1, then1), (when2, then2)],
1619            None,
1620            schema.as_ref(),
1621        )?;
1622        let result = expr
1623            .evaluate(&batch)?
1624            .into_array(batch.num_rows())
1625            .expect("Failed to convert to array");
1626        let result = as_int32_array(&result)?;
1627
1628        let expected = &Int32Array::from(vec![Some(123), Some(456), None, Some(123)]);
1629
1630        assert_eq!(expected, result);
1631
1632        Ok(())
1633    }
1634
1635    #[test]
1636    fn case_with_expr_all_null_dictionary() -> Result<()> {
1637        let schema = Schema::new(vec![Field::new(
1638            "a",
1639            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1640            true,
1641        )]);
1642        let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1643        let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1644        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1645        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1646
1647        let schema = batch.schema();
1648
1649        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1650        let when1 = lit("foo");
1651        let then1 = lit(123i32);
1652        let when2 = lit("bar");
1653        let then2 = lit(456i32);
1654
1655        let expr = generate_case_when_with_type_coercion(
1656            Some(col("a", &schema)?),
1657            vec![(when1, then1), (when2, then2)],
1658            None,
1659            schema.as_ref(),
1660        )?;
1661        let result = expr
1662            .evaluate(&batch)?
1663            .into_array(batch.num_rows())
1664            .expect("Failed to convert to array");
1665        let result = as_int32_array(&result)?;
1666
1667        let expected = &Int32Array::from(vec![None, None, None, None]);
1668
1669        assert_eq!(expected, result);
1670
1671        Ok(())
1672    }
1673
1674    #[test]
1675    fn case_with_expr_else() -> Result<()> {
1676        let batch = case_test_batch()?;
1677        let schema = batch.schema();
1678
1679        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
1680        let when1 = lit("foo");
1681        let then1 = lit(123i32);
1682        let when2 = lit("bar");
1683        let then2 = lit(456i32);
1684        let else_value = lit(999i32);
1685
1686        let expr = generate_case_when_with_type_coercion(
1687            Some(col("a", &schema)?),
1688            vec![(when1, then1), (when2, then2)],
1689            Some(else_value),
1690            schema.as_ref(),
1691        )?;
1692        let result = expr
1693            .evaluate(&batch)?
1694            .into_array(batch.num_rows())
1695            .expect("Failed to convert to array");
1696        let result = as_int32_array(&result)?;
1697
1698        let expected =
1699            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1700
1701        assert_eq!(expected, result);
1702
1703        Ok(())
1704    }
1705
1706    #[test]
1707    fn case_with_expr_divide_by_zero() -> Result<()> {
1708        let batch = case_test_batch1()?;
1709        let schema = batch.schema();
1710
1711        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
1712        let when1 = lit(0i32);
1713        let then1 = lit(ScalarValue::Float64(None));
1714        let else_value = binary(
1715            lit(25.0f64),
1716            Operator::Divide,
1717            cast(col("a", &schema)?, &batch.schema(), Float64)?,
1718            &batch.schema(),
1719        )?;
1720
1721        let expr = generate_case_when_with_type_coercion(
1722            Some(col("a", &schema)?),
1723            vec![(when1, then1)],
1724            Some(else_value),
1725            schema.as_ref(),
1726        )?;
1727        let result = expr
1728            .evaluate(&batch)?
1729            .into_array(batch.num_rows())
1730            .expect("Failed to convert to array");
1731        let result =
1732            as_float64_array(&result).expect("failed to downcast to Float64Array");
1733
1734        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1735
1736        assert_eq!(expected, result);
1737
1738        Ok(())
1739    }
1740
1741    #[test]
1742    fn case_without_expr() -> Result<()> {
1743        let batch = case_test_batch()?;
1744        let schema = batch.schema();
1745
1746        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
1747        let when1 = binary(
1748            col("a", &schema)?,
1749            Operator::Eq,
1750            lit("foo"),
1751            &batch.schema(),
1752        )?;
1753        let then1 = lit(123i32);
1754        let when2 = binary(
1755            col("a", &schema)?,
1756            Operator::Eq,
1757            lit("bar"),
1758            &batch.schema(),
1759        )?;
1760        let then2 = lit(456i32);
1761
1762        let expr = generate_case_when_with_type_coercion(
1763            None,
1764            vec![(when1, then1), (when2, then2)],
1765            None,
1766            schema.as_ref(),
1767        )?;
1768        let result = expr
1769            .evaluate(&batch)?
1770            .into_array(batch.num_rows())
1771            .expect("Failed to convert to array");
1772        let result = as_int32_array(&result)?;
1773
1774        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1775
1776        assert_eq!(expected, result);
1777
1778        Ok(())
1779    }
1780
1781    #[test]
1782    fn case_with_expr_when_null() -> Result<()> {
1783        let batch = case_test_batch()?;
1784        let schema = batch.schema();
1785
1786        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
1787        let when1 = lit(ScalarValue::Utf8(None));
1788        let then1 = lit(0i32);
1789        let when2 = col("a", &schema)?;
1790        let then2 = lit(123i32);
1791        let else_value = lit(999i32);
1792
1793        let expr = generate_case_when_with_type_coercion(
1794            Some(col("a", &schema)?),
1795            vec![(when1, then1), (when2, then2)],
1796            Some(else_value),
1797            schema.as_ref(),
1798        )?;
1799        let result = expr
1800            .evaluate(&batch)?
1801            .into_array(batch.num_rows())
1802            .expect("Failed to convert to array");
1803        let result = as_int32_array(&result)?;
1804
1805        let expected =
1806            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
1807
1808        assert_eq!(expected, result);
1809
1810        Ok(())
1811    }
1812
1813    #[test]
1814    fn case_without_expr_divide_by_zero() -> Result<()> {
1815        let batch = case_test_batch1()?;
1816        let schema = batch.schema();
1817
1818        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
1819        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1820        let then1 = binary(
1821            lit(25.0f64),
1822            Operator::Divide,
1823            cast(col("a", &schema)?, &batch.schema(), Float64)?,
1824            &batch.schema(),
1825        )?;
1826        let x = lit(ScalarValue::Float64(None));
1827
1828        let expr = generate_case_when_with_type_coercion(
1829            None,
1830            vec![(when1, then1)],
1831            Some(x),
1832            schema.as_ref(),
1833        )?;
1834        let result = expr
1835            .evaluate(&batch)?
1836            .into_array(batch.num_rows())
1837            .expect("Failed to convert to array");
1838        let result =
1839            as_float64_array(&result).expect("failed to downcast to Float64Array");
1840
1841        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1842
1843        assert_eq!(expected, result);
1844
1845        Ok(())
1846    }
1847
1848    fn case_test_batch1() -> Result<RecordBatch> {
1849        let schema = Schema::new(vec![
1850            Field::new("a", DataType::Int32, true),
1851            Field::new("b", DataType::Int32, true),
1852            Field::new("c", DataType::Int32, true),
1853        ]);
1854        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1855        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1856        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1857        let batch = RecordBatch::try_new(
1858            Arc::new(schema),
1859            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1860        )?;
1861        Ok(batch)
1862    }
1863
1864    #[test]
1865    fn case_without_expr_else() -> Result<()> {
1866        let batch = case_test_batch()?;
1867        let schema = batch.schema();
1868
1869        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
1870        let when1 = binary(
1871            col("a", &schema)?,
1872            Operator::Eq,
1873            lit("foo"),
1874            &batch.schema(),
1875        )?;
1876        let then1 = lit(123i32);
1877        let when2 = binary(
1878            col("a", &schema)?,
1879            Operator::Eq,
1880            lit("bar"),
1881            &batch.schema(),
1882        )?;
1883        let then2 = lit(456i32);
1884        let else_value = lit(999i32);
1885
1886        let expr = generate_case_when_with_type_coercion(
1887            None,
1888            vec![(when1, then1), (when2, then2)],
1889            Some(else_value),
1890            schema.as_ref(),
1891        )?;
1892        let result = expr
1893            .evaluate(&batch)?
1894            .into_array(batch.num_rows())
1895            .expect("Failed to convert to array");
1896        let result = as_int32_array(&result)?;
1897
1898        let expected =
1899            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1900
1901        assert_eq!(expected, result);
1902
1903        Ok(())
1904    }
1905
1906    #[test]
1907    fn case_with_type_cast() -> Result<()> {
1908        let batch = case_test_batch()?;
1909        let schema = batch.schema();
1910
1911        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
1912        let when = binary(
1913            col("a", &schema)?,
1914            Operator::Eq,
1915            lit("foo"),
1916            &batch.schema(),
1917        )?;
1918        let then = lit(123.3f64);
1919        let else_value = lit(999i32);
1920
1921        let expr = generate_case_when_with_type_coercion(
1922            None,
1923            vec![(when, then)],
1924            Some(else_value),
1925            schema.as_ref(),
1926        )?;
1927        let result = expr
1928            .evaluate(&batch)?
1929            .into_array(batch.num_rows())
1930            .expect("Failed to convert to array");
1931        let result =
1932            as_float64_array(&result).expect("failed to downcast to Float64Array");
1933
1934        let expected =
1935            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
1936
1937        assert_eq!(expected, result);
1938
1939        Ok(())
1940    }
1941
1942    #[test]
1943    fn case_with_matches_and_nulls() -> Result<()> {
1944        let batch = case_test_batch_nulls()?;
1945        let schema = batch.schema();
1946
1947        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
1948        let when = binary(
1949            col("load4", &schema)?,
1950            Operator::Eq,
1951            lit(1.77f64),
1952            &batch.schema(),
1953        )?;
1954        let then = col("load4", &schema)?;
1955
1956        let expr = generate_case_when_with_type_coercion(
1957            None,
1958            vec![(when, then)],
1959            None,
1960            schema.as_ref(),
1961        )?;
1962        let result = expr
1963            .evaluate(&batch)?
1964            .into_array(batch.num_rows())
1965            .expect("Failed to convert to array");
1966        let result =
1967            as_float64_array(&result).expect("failed to downcast to Float64Array");
1968
1969        let expected =
1970            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1971
1972        assert_eq!(expected, result);
1973
1974        Ok(())
1975    }
1976
1977    #[test]
1978    fn case_with_scalar_predicate() -> Result<()> {
1979        let batch = case_test_batch_nulls()?;
1980        let schema = batch.schema();
1981
1982        // SELECT CASE WHEN TRUE THEN load4 END
1983        let when = lit(true);
1984        let then = col("load4", &schema)?;
1985        let expr = generate_case_when_with_type_coercion(
1986            None,
1987            vec![(when, then)],
1988            None,
1989            schema.as_ref(),
1990        )?;
1991
1992        // many rows
1993        let result = expr
1994            .evaluate(&batch)?
1995            .into_array(batch.num_rows())
1996            .expect("Failed to convert to array");
1997        let result =
1998            as_float64_array(&result).expect("failed to downcast to Float64Array");
1999        let expected = &Float64Array::from(vec![
2000            Some(1.77),
2001            None,
2002            None,
2003            Some(1.78),
2004            None,
2005            Some(1.77),
2006        ]);
2007        assert_eq!(expected, result);
2008
2009        // one row
2010        let expected = Float64Array::from(vec![Some(1.1)]);
2011        let batch =
2012            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
2013        let result = expr
2014            .evaluate(&batch)?
2015            .into_array(batch.num_rows())
2016            .expect("Failed to convert to array");
2017        let result =
2018            as_float64_array(&result).expect("failed to downcast to Float64Array");
2019        assert_eq!(&expected, result);
2020
2021        Ok(())
2022    }
2023
2024    #[test]
2025    fn case_expr_matches_and_nulls() -> Result<()> {
2026        let batch = case_test_batch_nulls()?;
2027        let schema = batch.schema();
2028
2029        // SELECT CASE load4 WHEN 1.77 THEN load4 END
2030        let expr = col("load4", &schema)?;
2031        let when = lit(1.77f64);
2032        let then = col("load4", &schema)?;
2033
2034        let expr = generate_case_when_with_type_coercion(
2035            Some(expr),
2036            vec![(when, then)],
2037            None,
2038            schema.as_ref(),
2039        )?;
2040        let result = expr
2041            .evaluate(&batch)?
2042            .into_array(batch.num_rows())
2043            .expect("Failed to convert to array");
2044        let result =
2045            as_float64_array(&result).expect("failed to downcast to Float64Array");
2046
2047        let expected =
2048            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
2049
2050        assert_eq!(expected, result);
2051
2052        Ok(())
2053    }
2054
2055    #[test]
2056    fn test_when_null_and_some_cond_else_null() -> Result<()> {
2057        let batch = case_test_batch()?;
2058        let schema = batch.schema();
2059
2060        let when = binary(
2061            Arc::new(Literal::new(ScalarValue::Boolean(None))),
2062            Operator::And,
2063            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
2064            &schema,
2065        )?;
2066        let then = col("a", &schema)?;
2067
2068        // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END
2069        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
2070        let result = expr
2071            .evaluate(&batch)?
2072            .into_array(batch.num_rows())
2073            .expect("Failed to convert to array");
2074        let result = as_string_array(&result);
2075
2076        // all result values should be null
2077        assert_eq!(result.logical_null_count(), batch.num_rows());
2078        Ok(())
2079    }
2080
2081    fn case_test_batch() -> Result<RecordBatch> {
2082        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2083        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
2084        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
2085        Ok(batch)
2086    }
2087
2088    // Construct an array that has several NULL values whose
2089    // underlying buffer actually matches the where expr predicate
2090    fn case_test_batch_nulls() -> Result<RecordBatch> {
2091        let load4: Float64Array = vec![
2092            Some(1.77), // 1.77
2093            Some(1.77), // null <-- same value, but will be set to null
2094            Some(1.77), // null <-- same value, but will be set to null
2095            Some(1.78), // 1.78
2096            None,       // null
2097            Some(1.77), // 1.77
2098        ]
2099        .into_iter()
2100        .collect();
2101
2102        let null_buffer = Buffer::from([0b00101001u8]);
2103        let load4 = load4
2104            .into_data()
2105            .into_builder()
2106            .null_bit_buffer(Some(null_buffer))
2107            .build()
2108            .unwrap();
2109        let load4: Float64Array = load4.into();
2110
2111        let batch =
2112            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
2113        Ok(batch)
2114    }
2115
2116    #[test]
2117    fn case_test_incompatible() -> Result<()> {
2118        // 1 then is int64
2119        // 2 then is boolean
2120        let batch = case_test_batch()?;
2121        let schema = batch.schema();
2122
2123        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
2124        let when1 = binary(
2125            col("a", &schema)?,
2126            Operator::Eq,
2127            lit("foo"),
2128            &batch.schema(),
2129        )?;
2130        let then1 = lit(123i32);
2131        let when2 = binary(
2132            col("a", &schema)?,
2133            Operator::Eq,
2134            lit("bar"),
2135            &batch.schema(),
2136        )?;
2137        let then2 = lit(true);
2138
2139        let expr = generate_case_when_with_type_coercion(
2140            None,
2141            vec![(when1, then1), (when2, then2)],
2142            None,
2143            schema.as_ref(),
2144        );
2145        assert!(expr.is_err());
2146
2147        // then 1 is int32
2148        // then 2 is int64
2149        // else is float
2150        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
2151        let when1 = binary(
2152            col("a", &schema)?,
2153            Operator::Eq,
2154            lit("foo"),
2155            &batch.schema(),
2156        )?;
2157        let then1 = lit(123i32);
2158        let when2 = binary(
2159            col("a", &schema)?,
2160            Operator::Eq,
2161            lit("bar"),
2162            &batch.schema(),
2163        )?;
2164        let then2 = lit(456i64);
2165        let else_expr = lit(1.23f64);
2166
2167        let expr = generate_case_when_with_type_coercion(
2168            None,
2169            vec![(when1, then1), (when2, then2)],
2170            Some(else_expr),
2171            schema.as_ref(),
2172        );
2173        assert!(expr.is_ok());
2174        let result_type = expr.unwrap().data_type(schema.as_ref())?;
2175        assert_eq!(Float64, result_type);
2176        Ok(())
2177    }
2178
2179    #[test]
2180    fn case_eq() -> Result<()> {
2181        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2182
2183        let when1 = lit("foo");
2184        let then1 = lit(123i32);
2185        let when2 = lit("bar");
2186        let then2 = lit(456i32);
2187        let else_value = lit(999i32);
2188
2189        let expr1 = generate_case_when_with_type_coercion(
2190            Some(col("a", &schema)?),
2191            vec![
2192                (Arc::clone(&when1), Arc::clone(&then1)),
2193                (Arc::clone(&when2), Arc::clone(&then2)),
2194            ],
2195            Some(Arc::clone(&else_value)),
2196            &schema,
2197        )?;
2198
2199        let expr2 = generate_case_when_with_type_coercion(
2200            Some(col("a", &schema)?),
2201            vec![
2202                (Arc::clone(&when1), Arc::clone(&then1)),
2203                (Arc::clone(&when2), Arc::clone(&then2)),
2204            ],
2205            Some(Arc::clone(&else_value)),
2206            &schema,
2207        )?;
2208
2209        let expr3 = generate_case_when_with_type_coercion(
2210            Some(col("a", &schema)?),
2211            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
2212            None,
2213            &schema,
2214        )?;
2215
2216        let expr4 = generate_case_when_with_type_coercion(
2217            Some(col("a", &schema)?),
2218            vec![(when1, then1)],
2219            Some(else_value),
2220            &schema,
2221        )?;
2222
2223        assert!(expr1.eq(&expr2));
2224        assert!(expr2.eq(&expr1));
2225
2226        assert!(expr2.ne(&expr3));
2227        assert!(expr3.ne(&expr2));
2228
2229        assert!(expr1.ne(&expr4));
2230        assert!(expr4.ne(&expr1));
2231
2232        Ok(())
2233    }
2234
2235    #[test]
2236    fn case_transform() -> Result<()> {
2237        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2238
2239        let when1 = lit("foo");
2240        let then1 = lit(123i32);
2241        let when2 = lit("bar");
2242        let then2 = lit(456i32);
2243        let else_value = lit(999i32);
2244
2245        let expr = generate_case_when_with_type_coercion(
2246            Some(col("a", &schema)?),
2247            vec![
2248                (Arc::clone(&when1), Arc::clone(&then1)),
2249                (Arc::clone(&when2), Arc::clone(&then2)),
2250            ],
2251            Some(Arc::clone(&else_value)),
2252            &schema,
2253        )?;
2254
2255        let expr2 = Arc::clone(&expr)
2256            .transform(|e| {
2257                let transformed = match e.downcast_ref::<Literal>() {
2258                    Some(lit_value) => match lit_value.value() {
2259                        ScalarValue::Utf8(Some(str_value)) => {
2260                            Some(lit(str_value.to_uppercase()))
2261                        }
2262                        _ => None,
2263                    },
2264                    _ => None,
2265                };
2266                Ok(if let Some(transformed) = transformed {
2267                    Transformed::yes(transformed)
2268                } else {
2269                    Transformed::no(e)
2270                })
2271            })
2272            .data()
2273            .unwrap();
2274
2275        let expr3 = Arc::clone(&expr)
2276            .transform_down(|e| {
2277                let transformed = match e.downcast_ref::<Literal>() {
2278                    Some(lit_value) => match lit_value.value() {
2279                        ScalarValue::Utf8(Some(str_value)) => {
2280                            Some(lit(str_value.to_uppercase()))
2281                        }
2282                        _ => None,
2283                    },
2284                    _ => None,
2285                };
2286                Ok(if let Some(transformed) = transformed {
2287                    Transformed::yes(transformed)
2288                } else {
2289                    Transformed::no(e)
2290                })
2291            })
2292            .data()
2293            .unwrap();
2294
2295        assert!(expr.ne(&expr2));
2296        assert!(expr2.eq(&expr3));
2297
2298        Ok(())
2299    }
2300
2301    #[test]
2302    fn test_column_or_null_specialization() -> Result<()> {
2303        // create input data
2304        let mut c1 = Int32Builder::new();
2305        let mut c2 = StringBuilder::new();
2306        for i in 0..1000 {
2307            c1.append_value(i);
2308            if i % 7 == 0 {
2309                c2.append_null();
2310            } else {
2311                c2.append_value(format!("string {i}"));
2312            }
2313        }
2314        let c1 = Arc::new(c1.finish());
2315        let c2 = Arc::new(c2.finish());
2316        let schema = Schema::new(vec![
2317            Field::new("c1", DataType::Int32, true),
2318            Field::new("c2", DataType::Utf8, true),
2319        ]);
2320        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
2321
2322        // CaseWhenExprOrNull should produce same results as CaseExpr
2323        let predicate = Arc::new(BinaryExpr::new(
2324            make_col("c1", 0),
2325            Operator::LtEq,
2326            make_lit_i32(250),
2327        ));
2328        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
2329        assert_eq!(expr.eval_method, EvalMethod::InfallibleExprOrNull);
2330        match expr.evaluate(&batch)? {
2331            ColumnarValue::Array(array) => {
2332                assert_eq!(1000, array.len());
2333                assert_eq!(785, array.null_count());
2334            }
2335            _ => unreachable!(),
2336        }
2337        Ok(())
2338    }
2339
2340    #[test]
2341    fn test_expr_or_expr_specialization() -> Result<()> {
2342        let batch = case_test_batch1()?;
2343        let schema = batch.schema();
2344        let when = binary(
2345            col("a", &schema)?,
2346            Operator::LtEq,
2347            lit(2i32),
2348            &batch.schema(),
2349        )?;
2350        let then = col("b", &schema)?;
2351        let else_expr = col("c", &schema)?;
2352        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
2353        assert!(matches!(
2354            expr.eval_method,
2355            EvalMethod::ExpressionOrExpression(_)
2356        ));
2357        let result = expr
2358            .evaluate(&batch)?
2359            .into_array(batch.num_rows())
2360            .expect("Failed to convert to array");
2361        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
2362
2363        let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
2364
2365        assert_eq!(expected, result);
2366        Ok(())
2367    }
2368
2369    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2370        Arc::new(Column::new(name, index))
2371    }
2372
2373    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2374        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2375    }
2376
2377    fn generate_case_when_with_type_coercion(
2378        expr: Option<Arc<dyn PhysicalExpr>>,
2379        when_thens: Vec<WhenThen>,
2380        else_expr: Option<Arc<dyn PhysicalExpr>>,
2381        input_schema: &Schema,
2382    ) -> Result<Arc<dyn PhysicalExpr>> {
2383        let coerce_type =
2384            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2385        let (when_thens, else_expr) = match coerce_type {
2386            None => plan_err!(
2387                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2388            ),
2389            Some(data_type) => {
2390                // cast then expr
2391                let left = when_thens
2392                    .into_iter()
2393                    .map(|(when, then)| {
2394                        let then = try_cast(then, input_schema, data_type.clone())?;
2395                        Ok((when, then))
2396                    })
2397                    .collect::<Result<Vec<_>>>()?;
2398                let right = match else_expr {
2399                    None => None,
2400                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2401                };
2402
2403                Ok((left, right))
2404            }
2405        }?;
2406        case(expr, when_thens, else_expr)
2407    }
2408
2409    fn get_case_common_type(
2410        when_thens: &[WhenThen],
2411        else_expr: Option<Arc<dyn PhysicalExpr>>,
2412        input_schema: &Schema,
2413    ) -> Option<DataType> {
2414        let thens_type = when_thens
2415            .iter()
2416            .map(|when_then| {
2417                let data_type = &when_then.1.data_type(input_schema).unwrap();
2418                data_type.clone()
2419            })
2420            .collect::<Vec<_>>();
2421        let else_type = match else_expr {
2422            None => {
2423                // case when then exprs must have one then value
2424                thens_type[0].clone()
2425            }
2426            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2427        };
2428        thens_type
2429            .iter()
2430            .try_fold(else_type, |left_type, right_type| {
2431                type_union_coercion(&left_type, right_type)
2432            })
2433    }
2434
2435    #[test]
2436    fn test_fmt_sql() -> Result<()> {
2437        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2438
2439        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
2440        let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
2441        let then = lit(123.3f64);
2442        let else_value = lit(999i32);
2443
2444        let expr = generate_case_when_with_type_coercion(
2445            None,
2446            vec![(when, then)],
2447            Some(else_value),
2448            &schema,
2449        )?;
2450
2451        let display_string = expr.to_string();
2452        assert_eq!(
2453            display_string,
2454            "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2455        );
2456
2457        let sql_string = fmt_sql(expr.as_ref()).to_string();
2458        assert_eq!(
2459            sql_string,
2460            "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2461        );
2462
2463        Ok(())
2464    }
2465
2466    fn when_then_else(
2467        when: &Arc<dyn PhysicalExpr>,
2468        then: &Arc<dyn PhysicalExpr>,
2469        els: &Arc<dyn PhysicalExpr>,
2470    ) -> Result<Arc<dyn PhysicalExpr>> {
2471        let case = CaseExpr::try_new(
2472            None,
2473            vec![(Arc::clone(when), Arc::clone(then))],
2474            Some(Arc::clone(els)),
2475        )?;
2476        Ok(Arc::new(case))
2477    }
2478
2479    #[test]
2480    fn test_case_expression_nullability_with_nullable_column() -> Result<()> {
2481        case_expression_nullability(true)
2482    }
2483
2484    #[test]
2485    fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> {
2486        case_expression_nullability(false)
2487    }
2488
2489    fn case_expression_nullability(col_is_nullable: bool) -> Result<()> {
2490        let schema =
2491            Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]);
2492
2493        let foo = col("foo", &schema)?;
2494        let foo_is_not_null = is_not_null(Arc::clone(&foo))?;
2495        let foo_is_null = expressions::is_null(Arc::clone(&foo))?;
2496        let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?;
2497        let zero = lit(0);
2498        let foo_eq_zero =
2499            binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?;
2500
2501        assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema);
2502        assert_not_nullable(when_then_else(&not_foo_is_null, &foo, &zero)?, &schema);
2503        assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema);
2504
2505        assert_not_nullable(
2506            when_then_else(
2507                &binary(
2508                    Arc::clone(&foo_is_not_null),
2509                    Operator::And,
2510                    Arc::clone(&foo_eq_zero),
2511                    &schema,
2512                )?,
2513                &foo,
2514                &zero,
2515            )?,
2516            &schema,
2517        );
2518
2519        assert_not_nullable(
2520            when_then_else(
2521                &binary(
2522                    Arc::clone(&foo_eq_zero),
2523                    Operator::And,
2524                    Arc::clone(&foo_is_not_null),
2525                    &schema,
2526                )?,
2527                &foo,
2528                &zero,
2529            )?,
2530            &schema,
2531        );
2532
2533        assert_not_nullable(
2534            when_then_else(
2535                &binary(
2536                    Arc::clone(&foo_is_not_null),
2537                    Operator::Or,
2538                    Arc::clone(&foo_eq_zero),
2539                    &schema,
2540                )?,
2541                &foo,
2542                &zero,
2543            )?,
2544            &schema,
2545        );
2546
2547        assert_not_nullable(
2548            when_then_else(
2549                &binary(
2550                    Arc::clone(&foo_eq_zero),
2551                    Operator::Or,
2552                    Arc::clone(&foo_is_not_null),
2553                    &schema,
2554                )?,
2555                &foo,
2556                &zero,
2557            )?,
2558            &schema,
2559        );
2560
2561        assert_nullability(
2562            when_then_else(
2563                &binary(
2564                    Arc::clone(&foo_is_null),
2565                    Operator::Or,
2566                    Arc::clone(&foo_eq_zero),
2567                    &schema,
2568                )?,
2569                &foo,
2570                &zero,
2571            )?,
2572            &schema,
2573            col_is_nullable,
2574        );
2575
2576        assert_nullability(
2577            when_then_else(
2578                &binary(
2579                    binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?,
2580                    Operator::Or,
2581                    Arc::clone(&foo_is_null),
2582                    &schema,
2583                )?,
2584                &foo,
2585                &zero,
2586            )?,
2587            &schema,
2588            col_is_nullable,
2589        );
2590
2591        assert_not_nullable(
2592            when_then_else(
2593                &binary(
2594                    binary(
2595                        binary(
2596                            Arc::clone(&foo),
2597                            Operator::Eq,
2598                            Arc::clone(&zero),
2599                            &schema,
2600                        )?,
2601                        Operator::And,
2602                        Arc::clone(&foo_is_not_null),
2603                        &schema,
2604                    )?,
2605                    Operator::Or,
2606                    binary(
2607                        binary(
2608                            Arc::clone(&foo),
2609                            Operator::Eq,
2610                            Arc::clone(&foo),
2611                            &schema,
2612                        )?,
2613                        Operator::And,
2614                        Arc::clone(&foo_is_not_null),
2615                        &schema,
2616                    )?,
2617                    &schema,
2618                )?,
2619                &foo,
2620                &zero,
2621            )?,
2622            &schema,
2623        );
2624
2625        Ok(())
2626    }
2627
2628    fn assert_not_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2629        assert!(!expr.nullable(schema).unwrap());
2630    }
2631
2632    fn assert_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2633        assert!(expr.nullable(schema).unwrap());
2634    }
2635
2636    fn assert_nullability(expr: Arc<dyn PhysicalExpr>, schema: &Schema, nullable: bool) {
2637        if nullable {
2638            assert_nullable(expr, schema);
2639        } else {
2640            assert_not_nullable(expr, schema);
2641        }
2642    }
2643
2644    // Test Lookup evaluation
2645
2646    fn test_case_when_literal_lookup(
2647        values: ArrayRef,
2648        lookup_map: &[(ScalarValue, ScalarValue)],
2649        else_value: Option<ScalarValue>,
2650        expected: ArrayRef,
2651    ) {
2652        // Create lookup
2653        // CASE <expr>
2654        // WHEN <when_constant_1> THEN <then_constant_1>
2655        // WHEN <when_constant_2> THEN <then_constant_2>
2656        // [ ELSE <else_constant> ]
2657
2658        let schema = Schema::new(vec![Field::new(
2659            "a",
2660            values.data_type().clone(),
2661            values.is_nullable(),
2662        )]);
2663        let schema = Arc::new(schema);
2664
2665        let batch = RecordBatch::try_new(schema, vec![values])
2666            .expect("failed to create RecordBatch");
2667
2668        let schema = batch.schema_ref();
2669        let case = col("a", schema).expect("failed to create col");
2670
2671        let when_then = lookup_map
2672            .iter()
2673            .map(|(when, then)| {
2674                (
2675                    Arc::new(Literal::new(when.clone())) as _,
2676                    Arc::new(Literal::new(then.clone())) as _,
2677                )
2678            })
2679            .collect::<Vec<WhenThen>>();
2680
2681        let else_expr = else_value.map(|else_value| {
2682            Arc::new(Literal::new(else_value)) as Arc<dyn PhysicalExpr>
2683        });
2684        let expr = CaseExpr::try_new(Some(case), when_then, else_expr)
2685            .expect("failed to create case");
2686
2687        // Assert that we are testing what we intend to assert
2688        assert!(
2689            matches!(
2690                expr.eval_method,
2691                EvalMethod::WithExprScalarLookupTable { .. }
2692            ),
2693            "we should use the expected eval method"
2694        );
2695
2696        let actual = expr
2697            .evaluate(&batch)
2698            .expect("failed to evaluate case")
2699            .into_array(batch.num_rows())
2700            .expect("Failed to convert to array");
2701
2702        assert_eq!(
2703            actual.data_type(),
2704            expected.data_type(),
2705            "Data type mismatch"
2706        );
2707
2708        assert_eq!(
2709            actual.as_ref(),
2710            expected.as_ref(),
2711            "actual (left) does not match expected (right)"
2712        );
2713    }
2714
2715    fn create_lookup<When, Then>(
2716        when_then_pairs: impl IntoIterator<Item = (When, Then)>,
2717    ) -> Vec<(ScalarValue, ScalarValue)>
2718    where
2719        ScalarValue: From<When>,
2720        ScalarValue: From<Then>,
2721    {
2722        when_then_pairs
2723            .into_iter()
2724            .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then)))
2725            .collect()
2726    }
2727
2728    fn create_input_and_expected<Input, Expected, InputFromItem, ExpectedFromItem>(
2729        input_and_expected_pairs: impl IntoIterator<Item = (InputFromItem, ExpectedFromItem)>,
2730    ) -> (Input, Expected)
2731    where
2732        Input: Array + From<Vec<InputFromItem>>,
2733        Expected: Array + From<Vec<ExpectedFromItem>>,
2734    {
2735        let (input_items, expected_items): (Vec<InputFromItem>, Vec<ExpectedFromItem>) =
2736            input_and_expected_pairs.into_iter().unzip();
2737
2738        (Input::from(input_items), Expected::from(expected_items))
2739    }
2740
2741    fn test_lookup_eval_with_and_without_else(
2742        lookup_map: &[(ScalarValue, ScalarValue)],
2743        input_values: ArrayRef,
2744        expected: StringArray,
2745    ) {
2746        // Testing without ELSE should fallback to None
2747        test_case_when_literal_lookup(
2748            Arc::clone(&input_values),
2749            lookup_map,
2750            None,
2751            Arc::new(expected.clone()),
2752        );
2753
2754        // Testing with Else
2755        let else_value = "___fallback___";
2756
2757        // Changing each expected None to be fallback
2758        let expected_with_else = expected
2759            .iter()
2760            .map(|item| item.unwrap_or(else_value))
2761            .map(Some)
2762            .collect::<StringArray>();
2763
2764        // Test case
2765        test_case_when_literal_lookup(
2766            input_values,
2767            lookup_map,
2768            Some(ScalarValue::Utf8(Some(else_value.to_string()))),
2769            Arc::new(expected_with_else),
2770        );
2771    }
2772
2773    #[test]
2774    fn test_case_when_literal_lookup_int32_to_string() {
2775        let lookup_map = create_lookup([
2776            (Some(4), Some("four")),
2777            (Some(2), Some("two")),
2778            (Some(3), Some("three")),
2779            (Some(1), Some("one")),
2780        ]);
2781
2782        let (input_values, expected) =
2783            create_input_and_expected::<Int32Array, StringArray, _, _>([
2784                (1, Some("one")),
2785                (2, Some("two")),
2786                (3, Some("three")),
2787                (3, Some("three")),
2788                (2, Some("two")),
2789                (3, Some("three")),
2790                (5, None), // No match in WHEN
2791                (5, None), // No match in WHEN
2792                (3, Some("three")),
2793                (5, None), // No match in WHEN
2794            ]);
2795
2796        test_lookup_eval_with_and_without_else(
2797            &lookup_map,
2798            Arc::new(input_values),
2799            expected,
2800        );
2801    }
2802
2803    #[test]
2804    fn test_case_when_literal_lookup_none_case_should_never_match() {
2805        let lookup_map = create_lookup([
2806            (Some(4), Some("four")),
2807            (None, Some("none")),
2808            (Some(2), Some("two")),
2809            (Some(1), Some("one")),
2810        ]);
2811
2812        let (input_values, expected) =
2813            create_input_and_expected::<Int32Array, StringArray, _, _>([
2814                (Some(1), Some("one")),
2815                (Some(5), None), // No match in WHEN
2816                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
2817                (Some(2), Some("two")),
2818                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
2819                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
2820                (Some(2), Some("two")),
2821                (Some(5), None), // No match in WHEN
2822            ]);
2823
2824        test_lookup_eval_with_and_without_else(
2825            &lookup_map,
2826            Arc::new(input_values),
2827            expected,
2828        );
2829    }
2830
2831    #[test]
2832    fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() {
2833        let lookup_map = create_lookup([
2834            (Some(4), Some("four")),
2835            (Some(4), Some("no 4")),
2836            (Some(2), Some("two")),
2837            (Some(2), Some("no 2")),
2838            (Some(3), Some("three")),
2839            (Some(3), Some("no 3")),
2840            (Some(2), Some("no 2")),
2841            (Some(4), Some("no 4")),
2842            (Some(2), Some("no 2")),
2843            (Some(3), Some("no 3")),
2844            (Some(4), Some("no 4")),
2845            (Some(2), Some("no 2")),
2846            (Some(3), Some("no 3")),
2847            (Some(3), Some("no 3")),
2848        ]);
2849
2850        let (input_values, expected) =
2851            create_input_and_expected::<Int32Array, StringArray, _, _>([
2852                (1, None), // No match in WHEN
2853                (2, Some("two")),
2854                (3, Some("three")),
2855                (3, Some("three")),
2856                (2, Some("two")),
2857                (3, Some("three")),
2858                (5, None), // No match in WHEN
2859                (5, None), // No match in WHEN
2860                (3, Some("three")),
2861                (5, None), // No match in WHEN
2862            ]);
2863
2864        test_lookup_eval_with_and_without_else(
2865            &lookup_map,
2866            Arc::new(input_values),
2867            expected,
2868        );
2869    }
2870
2871    #[test]
2872    fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases()
2873     {
2874        let lookup_map = create_lookup([
2875            (Some(4.0), Some("four point zero")),
2876            (Some(f32::NAN), Some("NaN")),
2877            (Some(3.2), Some("three point two")),
2878            // Duplicate case to make sure it is not used
2879            (Some(f32::NAN), Some("should not use this NaN branch")),
2880            (Some(f32::INFINITY), Some("Infinity")),
2881            (Some(0.0), Some("zero")),
2882            // Duplicate case to make sure it is not used
2883            (
2884                Some(f32::INFINITY),
2885                Some("should not use this Infinity branch"),
2886            ),
2887            (Some(1.1), Some("one point one")),
2888        ]);
2889
2890        let (input_values, expected) =
2891            create_input_and_expected::<Float32Array, StringArray, _, _>([
2892                (1.1, Some("one point one")),
2893                (f32::NAN, Some("NaN")),
2894                (3.2, Some("three point two")),
2895                (3.2, Some("three point two")),
2896                (0.0, Some("zero")),
2897                (f32::INFINITY, Some("Infinity")),
2898                (3.2, Some("three point two")),
2899                (f32::NEG_INFINITY, None), // No match in WHEN
2900                (f32::NEG_INFINITY, None), // No match in WHEN
2901                (3.2, Some("three point two")),
2902                (-0.0, None), // No match in WHEN
2903            ]);
2904
2905        test_lookup_eval_with_and_without_else(
2906            &lookup_map,
2907            Arc::new(input_values),
2908            expected,
2909        );
2910    }
2911
2912    #[test]
2913    fn test_case_when_literal_lookup_f16_to_string_with_special_values() {
2914        let lookup_map = create_lookup([
2915            (
2916                ScalarValue::Float16(Some(f16::from_f32(3.2))),
2917                Some("3 dot 2"),
2918            ),
2919            (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")),
2920            (
2921                ScalarValue::Float16(Some(f16::from_f32(17.4))),
2922                Some("17 dot 4"),
2923            ),
2924            (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")),
2925            (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")),
2926        ]);
2927
2928        let (input_values, expected) =
2929            create_input_and_expected::<Float16Array, StringArray, _, _>([
2930                (f16::from_f32(3.2), Some("3 dot 2")),
2931                (f16::NAN, Some("NaN")),
2932                (f16::from_f32(17.4), Some("17 dot 4")),
2933                (f16::from_f32(17.4), Some("17 dot 4")),
2934                (f16::INFINITY, Some("Infinity")),
2935                (f16::from_f32(17.4), Some("17 dot 4")),
2936                (f16::NEG_INFINITY, None), // No match in WHEN
2937                (f16::NEG_INFINITY, None), // No match in WHEN
2938                (f16::from_f32(17.4), Some("17 dot 4")),
2939                (f16::NEG_ZERO, None), // No match in WHEN
2940            ]);
2941
2942        test_lookup_eval_with_and_without_else(
2943            &lookup_map,
2944            Arc::new(input_values),
2945            expected,
2946        );
2947    }
2948
2949    #[test]
2950    fn test_case_when_literal_lookup_f32_to_string_with_special_values() {
2951        let lookup_map = create_lookup([
2952            (3.2, Some("3 dot 2")),
2953            (f32::NAN, Some("NaN")),
2954            (17.4, Some("17 dot 4")),
2955            (f32::INFINITY, Some("Infinity")),
2956            (f32::ZERO, Some("zero")),
2957        ]);
2958
2959        let (input_values, expected) =
2960            create_input_and_expected::<Float32Array, StringArray, _, _>([
2961                (3.2, Some("3 dot 2")),
2962                (f32::NAN, Some("NaN")),
2963                (17.4, Some("17 dot 4")),
2964                (17.4, Some("17 dot 4")),
2965                (f32::INFINITY, Some("Infinity")),
2966                (17.4, Some("17 dot 4")),
2967                (f32::NEG_INFINITY, None), // No match in WHEN
2968                (f32::NEG_INFINITY, None), // No match in WHEN
2969                (17.4, Some("17 dot 4")),
2970                (-0.0, None), // No match in WHEN
2971            ]);
2972
2973        test_lookup_eval_with_and_without_else(
2974            &lookup_map,
2975            Arc::new(input_values),
2976            expected,
2977        );
2978    }
2979
2980    #[test]
2981    fn test_case_when_literal_lookup_f64_to_string_with_special_values() {
2982        let lookup_map = create_lookup([
2983            (3.2, Some("3 dot 2")),
2984            (f64::NAN, Some("NaN")),
2985            (17.4, Some("17 dot 4")),
2986            (f64::INFINITY, Some("Infinity")),
2987            (f64::ZERO, Some("zero")),
2988        ]);
2989
2990        let (input_values, expected) =
2991            create_input_and_expected::<Float64Array, StringArray, _, _>([
2992                (3.2, Some("3 dot 2")),
2993                (f64::NAN, Some("NaN")),
2994                (17.4, Some("17 dot 4")),
2995                (17.4, Some("17 dot 4")),
2996                (f64::INFINITY, Some("Infinity")),
2997                (17.4, Some("17 dot 4")),
2998                (f64::NEG_INFINITY, None), // No match in WHEN
2999                (f64::NEG_INFINITY, None), // No match in WHEN
3000                (17.4, Some("17 dot 4")),
3001                (-0.0, None), // No match in WHEN
3002            ]);
3003
3004        test_lookup_eval_with_and_without_else(
3005            &lookup_map,
3006            Arc::new(input_values),
3007            expected,
3008        );
3009    }
3010
3011    // Test that we don't lose the decimal precision and scale info
3012    #[test]
3013    fn test_decimal_with_non_default_precision_and_scale() {
3014        let lookup_map = create_lookup([
3015            (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")),
3016            (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")),
3017            (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")),
3018            (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")),
3019        ]);
3020
3021        let (input_values, expected) =
3022            create_input_and_expected::<Decimal32Array, StringArray, _, _>([
3023                (1, Some("one")),
3024                (2, Some("two")),
3025                (3, Some("three")),
3026                (3, Some("three")),
3027                (2, Some("two")),
3028                (3, Some("three")),
3029                (5, None), // No match in WHEN
3030                (5, None), // No match in WHEN
3031                (3, Some("three")),
3032                (5, None), // No match in WHEN
3033            ]);
3034
3035        let input_values = input_values
3036            .with_precision_and_scale(3, 2)
3037            .expect("must be able to set precision and scale");
3038
3039        test_lookup_eval_with_and_without_else(
3040            &lookup_map,
3041            Arc::new(input_values),
3042            expected,
3043        );
3044    }
3045
3046    // Test that we don't lose the timezone info
3047    #[test]
3048    fn test_timestamp_with_non_default_timezone() {
3049        let timezone: Option<Arc<str>> = Some("-10:00".into());
3050        let lookup_map = create_lookup([
3051            (
3052                ScalarValue::TimestampMillisecond(Some(4), timezone.clone()),
3053                Some("four"),
3054            ),
3055            (
3056                ScalarValue::TimestampMillisecond(Some(2), timezone.clone()),
3057                Some("two"),
3058            ),
3059            (
3060                ScalarValue::TimestampMillisecond(Some(3), timezone.clone()),
3061                Some("three"),
3062            ),
3063            (
3064                ScalarValue::TimestampMillisecond(Some(1), timezone.clone()),
3065                Some("one"),
3066            ),
3067        ]);
3068
3069        let (input_values, expected) =
3070            create_input_and_expected::<TimestampMillisecondArray, StringArray, _, _>([
3071                (1, Some("one")),
3072                (2, Some("two")),
3073                (3, Some("three")),
3074                (3, Some("three")),
3075                (2, Some("two")),
3076                (3, Some("three")),
3077                (5, None), // No match in WHEN
3078                (5, None), // No match in WHEN
3079                (3, Some("three")),
3080                (5, None), // No match in WHEN
3081            ]);
3082
3083        let input_values = input_values.with_timezone_opt(timezone);
3084
3085        test_lookup_eval_with_and_without_else(
3086            &lookup_map,
3087            Arc::new(input_values),
3088            expected,
3089        );
3090    }
3091
3092    #[test]
3093    fn test_with_strings_to_int32() {
3094        let lookup_map = create_lookup([
3095            (Some("why"), Some(42)),
3096            (Some("what"), Some(22)),
3097            (Some("when"), Some(17)),
3098        ]);
3099
3100        let (input_values, expected) =
3101            create_input_and_expected::<StringArray, Int32Array, _, _>([
3102                (Some("why"), Some(42)),
3103                (Some("5"), None), // No match in WHEN
3104                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
3105                (Some("what"), Some(22)),
3106                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
3107                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
3108                (Some("what"), Some(22)),
3109                (Some("5"), None), // No match in WHEN
3110            ]);
3111
3112        let input_values = Arc::new(input_values) as ArrayRef;
3113
3114        // Testing without ELSE should fallback to None
3115        test_case_when_literal_lookup(
3116            Arc::clone(&input_values),
3117            &lookup_map,
3118            None,
3119            Arc::new(expected.clone()),
3120        );
3121
3122        // Testing with Else
3123        let else_value = 101;
3124
3125        // Changing each expected None to be fallback
3126        let expected_with_else = expected
3127            .iter()
3128            .map(|item| item.unwrap_or(else_value))
3129            .map(Some)
3130            .collect::<Int32Array>();
3131
3132        // Test case
3133        test_case_when_literal_lookup(
3134            input_values,
3135            &lookup_map,
3136            Some(ScalarValue::Int32(Some(else_value))),
3137            Arc::new(expected_with_else),
3138        );
3139    }
3140
3141    /// Reproduces https://github.com/apache/datafusion/issues/22173
3142    ///
3143    /// Nested self-referential CASE chains (common in rewrite-style projections)
3144    /// should not cause exponential hashing work during physical planning.
3145    #[test]
3146    fn nested_self_referential_case_hash_stays_bounded() -> Result<()> {
3147        use std::hash::Hasher;
3148
3149        #[derive(Default)]
3150        struct CountingHasher {
3151            write_calls: usize,
3152            bytes_written: usize,
3153        }
3154
3155        impl Hasher for CountingHasher {
3156            fn finish(&self) -> u64 {
3157                0
3158            }
3159
3160            fn write(&mut self, bytes: &[u8]) {
3161                self.write_calls += 1;
3162                self.bytes_written += bytes.len();
3163            }
3164        }
3165
3166        let schema =
3167            Arc::new(Schema::new(vec![Field::new("kind", DataType::Utf8, true)]));
3168
3169        let kind = col("kind", &schema)?;
3170        let mut label = Arc::clone(&kind);
3171
3172        let num_levels = 18;
3173        for idx in 0..num_levels {
3174            let predicate = Arc::new(BinaryExpr::new(
3175                Arc::clone(&kind),
3176                Operator::Eq,
3177                lit(idx.to_string()),
3178            )) as Arc<dyn PhysicalExpr>;
3179
3180            label = case(None, vec![(predicate, lit("label"))], Some(label))?;
3181        }
3182
3183        let mut hasher = CountingHasher::default();
3184        label.hash(&mut hasher);
3185
3186        assert!(
3187            hasher.write_calls < 50_000,
3188            "hashing nested CASE expression took {} hasher writes and {} bytes",
3189            hasher.write_calls,
3190            hasher.bytes_written
3191        );
3192
3193        Ok(())
3194    }
3195}