datafusion_physical_expr/expressions/
case.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod literal_lookup_table;
19
20use super::{Column, Literal};
21use crate::PhysicalExpr;
22use crate::expressions::{lit, try_cast};
23use arrow::array::*;
24use arrow::compute::kernels::zip::zip;
25use arrow::compute::{
26    FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::{
32    DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err,
33    internal_datafusion_err, internal_err,
34};
35use datafusion_expr::ColumnarValue;
36use indexmap::{IndexMap, IndexSet};
37use std::borrow::Cow;
38use std::hash::Hash;
39use std::{any::Any, sync::Arc};
40
41use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
42use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
43use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
44use datafusion_physical_expr_common::datum::compare_with_eq;
45use itertools::Itertools;
46use std::fmt::{Debug, Formatter};
47
/// A single `WHEN ... THEN ...` branch of a CASE expression, stored as a
/// `(when, then)` pair of physical expressions.
pub(super) type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
49
/// The strategy used to evaluate a [`CaseExpr`].
///
/// The variant is selected once, when the expression is constructed, by
/// `CaseExpr::find_best_eval_method`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression:
    ///
    /// ```text
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    /// ```
    NoExpression(ProjectedCaseBody),
    /// General form with a base expression:
    ///
    /// ```text
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    /// ```
    WithExpression(ProjectedCaseBody),
    /// This is a specialization for a specific use case where we can take a fast path
    /// for expressions that are infallible and can be cheaply computed for the entire
    /// record batch rather than just for the rows where the predicate is true.
    ///
    /// ```text
    /// CASE WHEN condition THEN column [ELSE NULL] END
    /// ```
    InfallibleExprOrNull,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair and both the `then` and `else` expressions
    /// are literal values.
    ///
    /// ```text
    /// CASE WHEN condition THEN literal ELSE literal END
    /// ```
    ScalarOrScalar,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair and both the `then` and `else` are expressions.
    ///
    /// ```text
    /// CASE WHEN condition THEN expression ELSE expression END
    /// ```
    ExpressionOrExpression(ProjectedCaseBody),

    /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results
    /// are literals.
    ///
    /// See [`LiteralLookupTable`] for more details.
    WithExprScalarLookupTable(LiteralLookupTable),
}
85
/// Implements [`Hash`] so that `#[derive(Hash)]` can be used on [`EvalMethod`].
///
/// A real hash over the table contents is not implemented because [`Hash`] is not dyn
/// compatible, so it cannot be implemented for `dyn`
/// [`literal_lookup_table::WhenLiteralIndexMap`].
///
/// Feeding nothing into the hasher is still valid because the table is derived from
/// `PhysicalExpr`s which are already hashed.
impl Hash for LiteralLookupTable {
    // Intentionally contributes nothing to the hasher state.
    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
}
95
/// Implements [`PartialEq`] so that `#[derive(PartialEq)]` can be used on [`EvalMethod`].
///
/// A real comparison of the table contents is not implemented because [`PartialEq`] is not
/// dyn compatible, so it cannot be implemented for `dyn`
/// [`literal_lookup_table::WhenLiteralIndexMap`].
///
/// Instead every pair of tables compares equal; the table is derived from `PhysicalExpr`s
/// which are already compared.
impl PartialEq for LiteralLookupTable {
    fn eq(&self, _other: &Self) -> bool {
        // All lookup tables are considered equal to each other.
        true
    }
}

// Marker impl required by `#[derive(Eq)]` on `EvalMethod`.
impl Eq for LiteralLookupTable {}
109
/// The body of a CASE expression, which consists of an optional base expression, the
/// "when/then" branches and an optional "else" branch.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression that can be compared to literal values in the "when" expressions
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// One or more when/then expressions, each stored as a `(when, then)` pair
    when_then_expr: Vec<WhenThen>,
    /// Optional "else" expression
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
121
122impl CaseBody {
123    /// Derives a [ProjectedCaseBody] from this [CaseBody].
124    fn project(&self) -> Result<ProjectedCaseBody> {
125        // Determine the set of columns that are used in all the expressions of the case body.
126        let mut used_column_indices = IndexSet::<usize>::new();
127        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
128            expr.apply(|expr| {
129                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
130                    used_column_indices.insert(column.index());
131                }
132                Ok(TreeNodeRecursion::Continue)
133            })
134            .expect("Closure cannot fail");
135        };
136
137        if let Some(e) = &self.expr {
138            collect_column_indices(e);
139        }
140        self.when_then_expr.iter().for_each(|(w, t)| {
141            collect_column_indices(w);
142            collect_column_indices(t);
143        });
144        if let Some(e) = &self.else_expr {
145            collect_column_indices(e);
146        }
147
148        // Construct a mapping from the original column index to the projected column index.
149        let column_index_map = used_column_indices
150            .iter()
151            .enumerate()
152            .map(|(projected, original)| (*original, projected))
153            .collect::<IndexMap<usize, usize>>();
154
155        // Construct the projected body by rewriting each expression from the original body
156        // using the column index mapping.
157        let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
158            Arc::clone(expr)
159                .transform_down(|e| {
160                    if let Some(column) = e.as_any().downcast_ref::<Column>() {
161                        let original = column.index();
162                        let projected = *column_index_map.get(&original).unwrap();
163                        if projected != original {
164                            return Ok(Transformed::yes(Arc::new(Column::new(
165                                column.name(),
166                                projected,
167                            ))));
168                        }
169                    }
170                    Ok(Transformed::no(e))
171                })
172                .map(|t| t.data)
173        };
174
175        let projected_body = CaseBody {
176            expr: self.expr.as_ref().map(project).transpose()?,
177            when_then_expr: self
178                .when_then_expr
179                .iter()
180                .map(|(e, t)| Ok((project(e)?, project(t)?)))
181                .collect::<Result<Vec<_>>>()?,
182            else_expr: self.else_expr.as_ref().map(project).transpose()?,
183        };
184
185        // Construct the projection vector
186        let projection = column_index_map
187            .iter()
188            .sorted_by_key(|(_, v)| **v)
189            .map(|(k, _)| *k)
190            .collect::<Vec<_>>();
191
192        Ok(ProjectedCaseBody {
193            projection,
194            body: projected_body,
195        })
196    }
197}
198
/// A derived case body that can be used to evaluate a case expression after projecting
/// record batches using a projection vector.
///
/// This is used to avoid filtering columns of the input `RecordBatch` that are not
/// used by the `CASE` expression when progressively evaluating its remainder
/// batches. Filtering these columns is wasteful since for a record
/// batch of `n` rows, filtering requires at worst a copy of `n - 1` values
/// per array. If these filtered values will never be accessed, the time spent
/// producing them is better avoided.
///
/// For example, if we are evaluating the following case expression that
/// only references columns B and D:
///
/// ```sql
/// SELECT CASE WHEN B > 10 THEN D ELSE NULL END FROM (VALUES (...)) T(A, B, C, D)
/// ```
///
/// Of the 4 input columns `[A, B, C, D]`, the `CASE` expression only accesses `B` and `D`.
/// Filtering `A` and `C` would be unnecessary and wasteful.
///
/// If we only retain columns `B` and `D` using `RecordBatch::project` and the projection vector
/// `[1, 3]`, the indices of these two columns will change to `[0, 1]`. To evaluate the
/// case expression, it will need to be rewritten from `CASE WHEN B@1 > 10 THEN D@3 ELSE NULL END`
/// to `CASE WHEN B@0 > 10 THEN D@1 ELSE NULL END`.
///
/// The projection vector and the rewritten expression (which only differs from the original in
/// column reference indices) are held in a `ProjectedCaseBody`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// Original column indices to retain, listed in the order of their projected position.
    projection: Vec<usize>,
    /// The case body with its column references rewritten to the projected indices.
    body: CaseBody,
}
231
/// The CASE expression is similar to a series of nested if/else and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
/// ```text
/// CASE WHEN condition THEN result
///      [WHEN ...]
///      [ELSE result]
/// END
/// ```
///
/// The second form uses a base expression and then a series of "when" clauses that match on a
/// literal value.
///
/// ```text
/// CASE expression
///     WHEN value THEN result
///     [WHEN ...]
///     [ELSE result]
/// END
/// ```
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// The case expression body
    body: CaseBody,
    /// Evaluation method to use; chosen from `body` when the expression is created
    eval_method: EvalMethod,
}
256
257impl std::fmt::Display for CaseExpr {
258    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
259        write!(f, "CASE ")?;
260        if let Some(e) = &self.body.expr {
261            write!(f, "{e} ")?;
262        }
263        for (w, t) in &self.body.when_then_expr {
264            write!(f, "WHEN {w} THEN {t} ")?;
265        }
266        if let Some(e) = &self.body.else_expr {
267            write!(f, "ELSE {e} ")?;
268        }
269        write!(f, "END")
270    }
271}
272
273/// This is a specialization for a specific use case where we can take a fast path
274/// for expressions that are infallible and can be cheaply computed for the entire
275/// record batch rather than just for the rows where the predicate is true. For now,
276/// this is limited to use with Column expressions but could potentially be used for other
277/// expressions in the future
278fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
279    expr.as_any().is::<Column>()
280}
281
282/// Creates a [FilterPredicate] from a boolean array.
283fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
284    let mut filter_builder = FilterBuilder::new(predicate);
285    if optimize {
286        // Always optimize the filter since we use them multiple times.
287        filter_builder = filter_builder.optimize();
288    }
289    filter_builder.build()
290}
291
292fn multiple_arrays(data_type: &DataType) -> bool {
293    match data_type {
294        DataType::Struct(fields) => {
295            fields.len() > 1
296                || fields.len() == 1 && multiple_arrays(fields[0].data_type())
297        }
298        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
299        _ => false,
300    }
301}
302
303// This should be removed when https://github.com/apache/arrow-rs/pull/8693
304// is merged and becomes available.
305fn filter_record_batch(
306    record_batch: &RecordBatch,
307    filter: &FilterPredicate,
308) -> std::result::Result<RecordBatch, ArrowError> {
309    let filtered_columns = record_batch
310        .columns()
311        .iter()
312        .map(|a| filter_array(a, filter))
313        .collect::<std::result::Result<Vec<_>, _>>()?;
314    // SAFETY: since we start from a valid RecordBatch, there's no need to revalidate the schema
315    // since the set of columns has not changed.
316    // The input column arrays all had the same length (since they're coming from a valid RecordBatch)
317    // and the filtering them with the same filter will produces a new set of arrays with identical
318    // lengths.
319    unsafe {
320        Ok(RecordBatch::new_unchecked(
321            record_batch.schema(),
322            filtered_columns,
323            filter.count(),
324        ))
325    }
326}
327
328// This function exists purely to be able to use the same call style
329// for `filter_record_batch` and `filter_array` at the point of use.
330// When https://github.com/apache/arrow-rs/pull/8693 is available, replace
331// both with method calls on `FilterPredicate`.
332#[inline(always)]
333fn filter_array(
334    array: &dyn Array,
335    filter: &FilterPredicate,
336) -> std::result::Result<ArrayRef, ArrowError> {
337    filter.filter(array)
338}
339
/// An index into the partial results array that's more compact than `usize`.
///
/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
/// `Option` to keep the array of indices as compact as possible.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    /// The raw index, or `u32::MAX` when this is the 'none' placeholder.
    index: u32,
}
348
/// Raw value that marks a `PartialResultIndex` as the 'none' placeholder.
const NONE_VALUE: u32 = u32::MAX;
350
351impl PartialResultIndex {
352    /// Returns the 'none' placeholder value.
353    fn none() -> Self {
354        Self { index: NONE_VALUE }
355    }
356
357    fn zero() -> Self {
358        Self { index: 0 }
359    }
360
361    /// Creates a new partial result index.
362    ///
363    /// If the provided value is greater than or equal to `u32::MAX`
364    /// an error will be returned.
365    fn try_new(index: usize) -> Result<Self> {
366        let Ok(index) = u32::try_from(index) else {
367            return internal_err!("Partial result index exceeds limit");
368        };
369
370        assert_or_internal_err!(
371            index != NONE_VALUE,
372            "Partial result index exceeds limit"
373        );
374
375        Ok(Self { index })
376    }
377
378    /// Determines if this index is the 'none' placeholder value or not.
379    fn is_none(&self) -> bool {
380        self.index == NONE_VALUE
381    }
382}
383
384impl MergeIndex for PartialResultIndex {
385    /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise.
386    fn index(&self) -> Option<usize> {
387        if self.is_none() {
388            None
389        } else {
390            Some(self.index as usize)
391        }
392    }
393}
394
395impl Debug for PartialResultIndex {
396    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
397        if self.is_none() {
398            write!(f, "null")
399        } else {
400            write!(f, "{}", self.index)
401        }
402    }
403}
404
/// The accumulated state of a `ResultBuilder`.
enum ResultState {
    /// No result has been recorded yet. If nothing is added, the final result contains
    /// only null values.
    Empty,
    /// The final result needs to be computed by merging the data in `arrays`.
    Partial {
        /// The partial results that should be merged.
        /// `indices` contains indexes into this vec.
        arrays: Vec<ArrayRef>,
        /// Indicates per result row from which array in `arrays` a value should be taken.
        indices: Vec<PartialResultIndex>,
    },
    /// A single branch matched all input rows. When creating the final result, no further merging
    /// of partial results is necessary.
    Complete(ColumnarValue),
}
420
/// A builder for constructing result arrays for CASE expressions.
///
/// Rather than building a monolithic array containing all results, it maintains a set of
/// partial result arrays and a mapping that indicates for each row which partial array
/// contains the result value for that row.
///
/// On finish(), the builder will merge all partial results into a single array if necessary.
/// If all rows evaluated to the same array, that array can be returned directly without
/// any merging overhead.
struct ResultBuilder {
    /// The data type of the result; used to create the null result when nothing was added.
    data_type: DataType,
    /// The number of rows in the final result.
    row_count: usize,
    /// The results collected so far.
    state: ResultState,
}
436
impl ResultBuilder {
    /// Creates a new ResultBuilder that will produce arrays of the given data type.
    ///
    /// The `row_count` parameter indicates the number of rows in the final result.
    fn new(data_type: &DataType, row_count: usize) -> Self {
        Self {
            data_type: data_type.clone(),
            row_count,
            state: ResultState::Empty,
        }
    }

    /// Adds a result for one branch of the case expression.
    ///
    /// `row_indices` should be a [UInt32Array] containing [RecordBatch] relative row indices
    /// for which `value` contains result values.
    ///
    /// If `value` is a scalar, the scalar value will be used as the value for each row in `row_indices`.
    ///
    /// If `value` is an array, the values from the array and the indices from `row_indices` will be
    /// processed pairwise. The lengths of `value` and `row_indices` must match, otherwise an
    /// internal error is returned.
    ///
    /// If `row_indices` covers every row of the final result, `value` is stored as the complete
    /// result; otherwise it is stored as a partial result.
    ///
    /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the
    /// record batch. The then expression produced the value array `[C, D]`.
    /// After adding this result, the result array will have been added to `partial arrays` and
    /// `partial indices` will have been updated at indexes `1` and `4`.
    ///
    /// ```text
    ///  ┌─────────┐     ┌─────────┐┌───────────┐                            ┌─────────┐┌───────────┐
    ///  │    C    │     │ 0: None ││┌ 0 ──────┐│                            │ 0: None ││┌ 0 ──────┐│
    ///  ├─────────┤     ├─────────┤││    A    ││                            ├─────────┤││    A    ││
    ///  │    D    │     │ 1: None ││└─────────┘│                            │ 1:  2   ││└─────────┘│
    ///  └─────────┘     ├─────────┤│┌ 1 ──────┐│   add_branch_result(       ├─────────┤│┌ 1 ──────┐│
    ///   matching       │ 2:  0   │││    B    ││     row indices,           │ 2:  0   │││    B    ││
    /// 'then' values    ├─────────┤│└─────────┘│     value                  ├─────────┤│└─────────┘│
    ///                  │ 3: None ││           │   )                        │ 3: None ││┌ 2 ──────┐│
    ///  ┌─────────┐     ├─────────┤│           │ ─────────────────────────▶ ├─────────┤││    C    ││
    ///  │    1    │     │ 4: None ││           │                            │ 4:  2   ││├─────────┤│
    ///  ├─────────┤     ├─────────┤│           │                            ├─────────┤││    D    ││
    ///  │    4    │     │ 5:  1   ││           │                            │ 5:  1   ││└─────────┘│
    ///  └─────────┘     └─────────┘└───────────┘                            └─────────┘└───────────┘
    /// row indices        partial     partial                                 partial     partial
    ///                    indices     arrays                                  indices     arrays
    /// ```
    fn add_branch_result(
        &mut self,
        row_indices: &ArrayRef,
        value: ColumnarValue,
    ) -> Result<()> {
        match value {
            ColumnarValue::Array(a) => {
                if a.len() != row_indices.len() {
                    internal_err!("Array length must match row indices length")
                } else if row_indices.len() == self.row_count {
                    // The branch covers every row, so no merging will be needed.
                    self.set_complete_result(ColumnarValue::Array(a))
                } else {
                    self.add_partial_result(row_indices, a)
                }
            }
            ColumnarValue::Scalar(s) => {
                if row_indices.len() == self.row_count {
                    // Keep the scalar as is; it applies to every row.
                    self.set_complete_result(ColumnarValue::Scalar(s))
                } else {
                    // Expand the scalar to one value per matched row.
                    self.add_partial_result(
                        row_indices,
                        s.to_array_of_size(row_indices.len())?,
                    )
                }
            }
        }
    }

    /// Adds a partial result array.
    ///
    /// This method adds the given array data as a partial result and updates the index mapping
    /// to indicate that the specified rows should take their values from this array.
    /// The partial results will be merged into a single array when finish() is called.
    ///
    /// Returns an internal error if `row_indices` contains nulls or if a complete result has
    /// already been set. Each row index is used to index a vector of `row_count` entries, so
    /// it must be smaller than `row_count`.
    fn add_partial_result(
        &mut self,
        row_indices: &ArrayRef,
        row_values: ArrayRef,
    ) -> Result<()> {
        assert_or_internal_err!(
            row_indices.null_count() == 0,
            "Row indices must not contain nulls"
        );

        match &mut self.state {
            ResultState::Empty => {
                // First partial result: start with every row pointing at 'none' and mark the
                // given rows as taking their value from array 0.
                let array_index = PartialResultIndex::zero();
                let mut indices = vec![PartialResultIndex::none(); self.row_count];
                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
                    indices[*row_ix as usize] = array_index;
                }

                self.state = ResultState::Partial {
                    arrays: vec![row_values],
                    indices,
                };

                Ok(())
            }
            ResultState::Partial { arrays, indices } => {
                // The new array is appended, so its index is the current number of arrays.
                let array_index = PartialResultIndex::try_new(arrays.len())?;

                arrays.push(row_values);

                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
                    // This check is only active in debug builds because the callers of this method,
                    // `case_when_with_expr` and `case_when_no_expr`, already ensure that
                    // they only calculate a value for each row at most once.
                    #[cfg(debug_assertions)]
                    assert_or_internal_err!(
                        indices[*row_ix as usize].is_none(),
                        "Duplicate value for row {}",
                        *row_ix
                    );

                    indices[*row_ix as usize] = array_index;
                }
                Ok(())
            }
            ResultState::Complete(_) => internal_err!(
                "Cannot add a partial result when complete result is already set"
            ),
        }
    }

    /// Sets a result that applies to all rows.
    ///
    /// This is an optimization for cases where all rows evaluate to the same result.
    /// When a complete result is set, the builder will return it directly from finish()
    /// without any merging overhead.
    ///
    /// Returns an internal error if any partial or complete result was recorded before.
    fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
        match &self.state {
            ResultState::Empty => {
                self.state = ResultState::Complete(value);
                Ok(())
            }
            ResultState::Partial { .. } => {
                internal_err!(
                    "Cannot set a complete result when there are already partial results"
                )
            }
            ResultState::Complete(_) => internal_err!("Complete result already set"),
        }
    }

    /// Finishes building the result and returns the final value.
    ///
    /// Depending on the state this is a null scalar of the builder's data type, the merged
    /// array of all partial results, or the complete result exactly as it was set.
    fn finish(self) -> Result<ColumnarValue> {
        match self.state {
            ResultState::Empty => {
                // No complete result and no partial results.
                // This can happen for case expressions with no else branch where no rows
                // matched.
                Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
                    &self.data_type,
                )?))
            }
            ResultState::Partial { arrays, indices } => {
                // Merge partial results into a single array.
                let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
                Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
            }
            ResultState::Complete(v) => {
                // If we have a complete result, we can just return it.
                Ok(v)
            }
        }
    }
}
608
609impl CaseExpr {
610    /// Create a new CASE WHEN expression
611    pub fn try_new(
612        expr: Option<Arc<dyn PhysicalExpr>>,
613        when_then_expr: Vec<WhenThen>,
614        else_expr: Option<Arc<dyn PhysicalExpr>>,
615    ) -> Result<Self> {
616        // normalize null literals to None in the else_expr (this already happens
617        // during SQL planning, but not necessarily for other use cases)
618        let else_expr = match &else_expr {
619            Some(e) => match e.as_any().downcast_ref::<Literal>() {
620                Some(lit) if lit.value().is_null() => None,
621                _ => else_expr,
622            },
623            _ => else_expr,
624        };
625
626        if when_then_expr.is_empty() {
627            return exec_err!("There must be at least one WHEN clause");
628        }
629
630        let body = CaseBody {
631            expr,
632            when_then_expr,
633            else_expr,
634        };
635
636        let eval_method = Self::find_best_eval_method(&body)?;
637
638        Ok(Self { body, eval_method })
639    }
640
641    fn find_best_eval_method(body: &CaseBody) -> Result<EvalMethod> {
642        if body.expr.is_some() {
643            if let Some(mapping) = LiteralLookupTable::maybe_new(body) {
644                return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
645            }
646
647            return Ok(EvalMethod::WithExpression(body.project()?));
648        }
649
650        Ok(
651            if body.when_then_expr.len() == 1
652                && is_cheap_and_infallible(&(body.when_then_expr[0].1))
653                && body.else_expr.is_none()
654            {
655                EvalMethod::InfallibleExprOrNull
656            } else if body.when_then_expr.len() == 1
657                && body.when_then_expr[0].1.as_any().is::<Literal>()
658                && body.else_expr.is_some()
659                && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
660            {
661                EvalMethod::ScalarOrScalar
662            } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
663                EvalMethod::ExpressionOrExpression(body.project()?)
664            } else {
665                EvalMethod::NoExpression(body.project()?)
666            },
667        )
668    }
669
670    /// Optional base expression that can be compared to literal values in the "when" expressions
671    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
672        self.body.expr.as_ref()
673    }
674
675    /// One or more when/then expressions
676    pub fn when_then_expr(&self) -> &[WhenThen] {
677        &self.body.when_then_expr
678    }
679
680    /// Optional "else" expression
681    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
682        self.body.else_expr.as_ref()
683    }
684}
685
686impl CaseBody {
687    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
688        // since all then results have the same data type, we can choose any one as the
689        // return data type except for the null.
690        let mut data_type = DataType::Null;
691        for i in 0..self.when_then_expr.len() {
692            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
693            if !data_type.equals_datatype(&DataType::Null) {
694                break;
695            }
696        }
697        // if all then results are null, we use data type of else expr instead if possible.
698        if data_type.equals_datatype(&DataType::Null)
699            && let Some(e) = &self.else_expr
700        {
701            data_type = e.data_type(input_schema)?;
702        }
703
704        Ok(data_type)
705    }
706
707    /// See [CaseExpr::case_when_with_expr].
708    fn case_when_with_expr(
709        &self,
710        batch: &RecordBatch,
711        return_type: &DataType,
712    ) -> Result<ColumnarValue> {
713        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
714
715        // `remainder_rows` contains the indices of the rows that need to be evaluated
716        let mut remainder_rows: ArrayRef =
717            Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
718        // `remainder_batch` contains the rows themselves that need to be evaluated
719        let mut remainder_batch = Cow::Borrowed(batch);
720
721        // evaluate the base expression
722        let mut base_values = self
723            .expr
724            .as_ref()
725            .unwrap()
726            .evaluate(batch)?
727            .into_array(batch.num_rows())?;
728
729        // Fill in a result value already for rows where the base expression value is null
730        // Since each when expression is tested against the base expression using the equality
731        // operator, null base values can never match any when expression. `x = NULL` is falsy,
732        // for all possible values of `x`.
733        let base_null_count = base_values.logical_null_count();
734        if base_null_count > 0 {
735            // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
736            // We already checked there are nulls, so we can be sure a new buffer will not be
737            // created.
738            let base_not_nulls = is_not_null(base_values.as_ref())?;
739            let base_all_null = base_null_count == remainder_batch.num_rows();
740
741            // If there is an else expression, use that as the default value for the null rows
742            // Otherwise the default `null` value from the result builder will be used.
743            if let Some(e) = &self.else_expr {
744                let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
745
746                if base_all_null {
747                    // All base values were null, so no need to filter
748                    let nulls_value = expr.evaluate(&remainder_batch)?;
749                    result_builder.add_branch_result(&remainder_rows, nulls_value)?;
750                } else {
751                    // Filter out the null rows and evaluate the else expression for those
752                    let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
753                    let nulls_batch =
754                        filter_record_batch(&remainder_batch, &nulls_filter)?;
755                    let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
756                    let nulls_value = expr.evaluate(&nulls_batch)?;
757                    result_builder.add_branch_result(&nulls_rows, nulls_value)?;
758                }
759            }
760
761            // All base values are null, so we can return early
762            if base_all_null {
763                return result_builder.finish();
764            }
765
766            // Remove the null rows from the remainder batch
767            let not_null_filter = create_filter(&base_not_nulls, true);
768            remainder_batch =
769                Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
770            remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
771            base_values = filter_array(&base_values, &not_null_filter)?;
772        }
773
774        // The types of case and when expressions will be coerced to match.
775        // We only need to check if the base_value is nested.
776        let base_value_is_nested = base_values.data_type().is_nested();
777
778        for i in 0..self.when_then_expr.len() {
779            // Evaluate the 'when' predicate for the remainder batch
780            // This results in a boolean array with the same length as the remaining number of rows
781            let when_expr = &self.when_then_expr[i].0;
782            let when_value = match when_expr.evaluate(&remainder_batch)? {
783                ColumnarValue::Array(a) => {
784                    compare_with_eq(&a, &base_values, base_value_is_nested)
785                }
786                ColumnarValue::Scalar(s) => {
787                    compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
788                }
789            }?;
790
791            // `true_count` ignores `true` values where the validity bit is not set, so there's
792            // no need to call `prep_null_mask_filter`.
793            let when_true_count = when_value.true_count();
794
795            // If the 'when' predicate did not match any rows, continue to the next branch immediately
796            if when_true_count == 0 {
797                continue;
798            }
799
800            // If the 'when' predicate matched all remaining rows, there is no need to filter
801            if when_true_count == remainder_batch.num_rows() {
802                let then_expression = &self.when_then_expr[i].1;
803                let then_value = then_expression.evaluate(&remainder_batch)?;
804                result_builder.add_branch_result(&remainder_rows, then_value)?;
805                return result_builder.finish();
806            }
807
808            // Filter the remainder batch based on the 'when' value
809            // This results in a batch containing only the rows that need to be evaluated
810            // for the current branch
811            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
812            // this unconditionally.
813            let then_filter = create_filter(&when_value, true);
814            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
815            let then_rows = filter_array(&remainder_rows, &then_filter)?;
816
817            let then_expression = &self.when_then_expr[i].1;
818            let then_value = then_expression.evaluate(&then_batch)?;
819            result_builder.add_branch_result(&then_rows, then_value)?;
820
821            // If this is the last 'when' branch and there is no 'else' expression, there's no
822            // point in calculating the remaining rows.
823            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
824                return result_builder.finish();
825            }
826
827            // Prepare the next when branch (or the else branch)
828            let next_selection = match when_value.null_count() {
829                0 => not(&when_value),
830                _ => {
831                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
832                    // as false
833                    not(&prep_null_mask_filter(&when_value))
834                }
835            }?;
836            let next_filter = create_filter(&next_selection, true);
837            remainder_batch =
838                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
839            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
840            base_values = filter_array(&base_values, &next_filter)?;
841        }
842
843        // If we reached this point, some rows were left unmatched.
844        // Check if those need to be evaluated using the 'else' expression.
845        if let Some(e) = &self.else_expr {
846            // keep `else_expr`'s data type and return type consistent
847            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
848            let else_value = expr.evaluate(&remainder_batch)?;
849            result_builder.add_branch_result(&remainder_rows, else_value)?;
850        }
851
852        result_builder.finish()
853    }
854
855    /// See [CaseExpr::case_when_no_expr].
856    fn case_when_no_expr(
857        &self,
858        batch: &RecordBatch,
859        return_type: &DataType,
860    ) -> Result<ColumnarValue> {
861        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
862
863        // `remainder_rows` contains the indices of the rows that need to be evaluated
864        let mut remainder_rows: ArrayRef =
865            Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
866        // `remainder_batch` contains the rows themselves that need to be evaluated
867        let mut remainder_batch = Cow::Borrowed(batch);
868
869        for i in 0..self.when_then_expr.len() {
870            // Evaluate the 'when' predicate for the remainder batch
871            // This results in a boolean array with the same length as the remaining number of rows
872            let when_predicate = &self.when_then_expr[i].0;
873            let when_value = when_predicate
874                .evaluate(&remainder_batch)?
875                .into_array(remainder_batch.num_rows())?;
876            let when_value = as_boolean_array(&when_value).map_err(|_| {
877                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
878            })?;
879
880            // `true_count` ignores `true` values where the validity bit is not set, so there's
881            // no need to call `prep_null_mask_filter`.
882            let when_true_count = when_value.true_count();
883
884            // If the 'when' predicate did not match any rows, continue to the next branch immediately
885            if when_true_count == 0 {
886                continue;
887            }
888
889            // If the 'when' predicate matched all remaining rows, there is no need to filter
890            if when_true_count == remainder_batch.num_rows() {
891                let then_expression = &self.when_then_expr[i].1;
892                let then_value = then_expression.evaluate(&remainder_batch)?;
893                result_builder.add_branch_result(&remainder_rows, then_value)?;
894                return result_builder.finish();
895            }
896
897            // Filter the remainder batch based on the 'when' value
898            // This results in a batch containing only the rows that need to be evaluated
899            // for the current branch
900            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
901            // this unconditionally.
902            let then_filter = create_filter(when_value, true);
903            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
904            let then_rows = filter_array(&remainder_rows, &then_filter)?;
905
906            let then_expression = &self.when_then_expr[i].1;
907            let then_value = then_expression.evaluate(&then_batch)?;
908            result_builder.add_branch_result(&then_rows, then_value)?;
909
910            // If this is the last 'when' branch and there is no 'else' expression, there's no
911            // point in calculating the remaining rows.
912            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
913                return result_builder.finish();
914            }
915
916            // Prepare the next when branch (or the else branch)
917            let next_selection = match when_value.null_count() {
918                0 => not(when_value),
919                _ => {
920                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
921                    // as false
922                    not(&prep_null_mask_filter(when_value))
923                }
924            }?;
925            let next_filter = create_filter(&next_selection, true);
926            remainder_batch =
927                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
928            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
929        }
930
931        // If we reached this point, some rows were left unmatched.
932        // Check if those need to be evaluated using the 'else' expression.
933        if let Some(e) = &self.else_expr {
934            // keep `else_expr`'s data type and return type consistent
935            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
936            let else_value = expr.evaluate(&remainder_batch)?;
937            result_builder.add_branch_result(&remainder_rows, else_value)?;
938        }
939
940        result_builder.finish()
941    }
942
943    /// See [CaseExpr::expr_or_expr].
944    fn expr_or_expr(
945        &self,
946        batch: &RecordBatch,
947        when_value: &BooleanArray,
948    ) -> Result<ColumnarValue> {
949        let when_value = match when_value.null_count() {
950            0 => Cow::Borrowed(when_value),
951            _ => {
952                // `prep_null_mask_filter` is required to ensure null is treated as false
953                Cow::Owned(prep_null_mask_filter(when_value))
954            }
955        };
956
957        let optimize_filter = batch.num_columns() > 1
958            || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
959
960        let when_filter = create_filter(&when_value, optimize_filter);
961        let then_batch = filter_record_batch(batch, &when_filter)?;
962        let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
963
964        let else_selection = not(&when_value)?;
965        let else_filter = create_filter(&else_selection, optimize_filter);
966        let else_batch = filter_record_batch(batch, &else_filter)?;
967
968        // keep `else_expr`'s data type and return type consistent
969        let e = self.else_expr.as_ref().unwrap();
970        let return_type = self.data_type(&batch.schema())?;
971        let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
972            .unwrap_or_else(|_| Arc::clone(e));
973
974        let else_value = else_expr.evaluate(&else_batch)?;
975
976        Ok(ColumnarValue::Array(match (then_value, else_value) {
977            (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
978                merge(&when_value, &t, &e)
979            }
980            (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
981                merge(&when_value, &t.to_scalar()?, &e)
982            }
983            (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
984                merge(&when_value, &t, &e.to_scalar()?)
985            }
986            (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
987                merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
988            }
989        }?))
990    }
991}
992
993impl CaseExpr {
994    /// This function evaluates the form of CASE that matches an expression to fixed values.
995    ///
996    /// CASE expression
997    ///     WHEN value THEN result
998    ///     [WHEN ...]
999    ///     [ELSE result]
1000    /// END
1001    fn case_when_with_expr(
1002        &self,
1003        batch: &RecordBatch,
1004        projected: &ProjectedCaseBody,
1005    ) -> Result<ColumnarValue> {
1006        let return_type = self.data_type(&batch.schema())?;
1007        if projected.projection.len() < batch.num_columns() {
1008            let projected_batch = batch.project(&projected.projection)?;
1009            projected
1010                .body
1011                .case_when_with_expr(&projected_batch, &return_type)
1012        } else {
1013            self.body.case_when_with_expr(batch, &return_type)
1014        }
1015    }
1016
1017    /// This function evaluates the form of CASE where each WHEN expression is a boolean
1018    /// expression.
1019    ///
1020    /// CASE WHEN condition THEN result
1021    ///      [WHEN ...]
1022    ///      [ELSE result]
1023    /// END
1024    fn case_when_no_expr(
1025        &self,
1026        batch: &RecordBatch,
1027        projected: &ProjectedCaseBody,
1028    ) -> Result<ColumnarValue> {
1029        let return_type = self.data_type(&batch.schema())?;
1030        if projected.projection.len() < batch.num_columns() {
1031            let projected_batch = batch.project(&projected.projection)?;
1032            projected
1033                .body
1034                .case_when_no_expr(&projected_batch, &return_type)
1035        } else {
1036            self.body.case_when_no_expr(batch, &return_type)
1037        }
1038    }
1039
1040    /// This function evaluates the specialized case of:
1041    ///
1042    /// CASE WHEN condition THEN column
1043    ///      [ELSE NULL]
1044    /// END
1045    ///
1046    /// Note that this function is only safe to use for "then" expressions
1047    /// that are infallible because the expression will be evaluated for all
1048    /// rows in the input batch.
1049    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1050        let when_expr = &self.body.when_then_expr[0].0;
1051        let then_expr = &self.body.when_then_expr[0].1;
1052
1053        match when_expr.evaluate(batch)? {
1054            // WHEN true --> column
1055            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1056                then_expr.evaluate(batch)
1057            }
1058            // WHEN [false | null] --> NULL
1059            ColumnarValue::Scalar(_) => {
1060                // return scalar NULL value
1061                ScalarValue::try_from(self.data_type(&batch.schema())?)
1062                    .map(ColumnarValue::Scalar)
1063            }
1064            // WHEN column --> column
1065            ColumnarValue::Array(bit_mask) => {
1066                let bit_mask = bit_mask
1067                    .as_any()
1068                    .downcast_ref::<BooleanArray>()
1069                    .expect("predicate should evaluate to a boolean array");
1070                // invert the bitmask
1071                let bit_mask = match bit_mask.null_count() {
1072                    0 => not(bit_mask)?,
1073                    _ => not(&prep_null_mask_filter(bit_mask))?,
1074                };
1075                match then_expr.evaluate(batch)? {
1076                    ColumnarValue::Array(array) => {
1077                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1078                    }
1079                    ColumnarValue::Scalar(_) => {
1080                        internal_err!("expression did not evaluate to an array")
1081                    }
1082                }
1083            }
1084        }
1085    }
1086
1087    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1088        let return_type = self.data_type(&batch.schema())?;
1089
1090        // evaluate when expression
1091        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1092        let when_value = when_value.into_array(batch.num_rows())?;
1093        let when_value = as_boolean_array(&when_value).map_err(|_| {
1094            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1095        })?;
1096
1097        // Treat 'NULL' as false value
1098        let when_value = match when_value.null_count() {
1099            0 => Cow::Borrowed(when_value),
1100            _ => Cow::Owned(prep_null_mask_filter(when_value)),
1101        };
1102
1103        // evaluate then_value
1104        let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1105        let then_value = Scalar::new(then_value.into_array(1)?);
1106
1107        let Some(e) = &self.body.else_expr else {
1108            return internal_err!("expression did not evaluate to an array");
1109        };
1110        // keep `else_expr`'s data type and return type consistent
1111        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1112        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1113        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1114    }
1115
1116    fn expr_or_expr(
1117        &self,
1118        batch: &RecordBatch,
1119        projected: &ProjectedCaseBody,
1120    ) -> Result<ColumnarValue> {
1121        // evaluate when condition on batch
1122        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1123        // `num_rows == 1` is intentional to avoid expanding scalars.
1124        // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1125        // below will avoid incorrectly using the scalar as a merge/zip mask.
1126        let when_value = when_value.into_array(1)?;
1127        let when_value = as_boolean_array(&when_value).map_err(|e| {
1128            DataFusionError::Context(
1129                "WHEN expression did not return a BooleanArray".to_string(),
1130                Box::new(e),
1131            )
1132        })?;
1133
1134        let true_count = when_value.true_count();
1135        if true_count == when_value.len() {
1136            // All input rows are true, just call the 'then' expression
1137            self.body.when_then_expr[0].1.evaluate(batch)
1138        } else if true_count == 0 {
1139            // All input rows are false/null, just call the 'else' expression
1140            self.body.else_expr.as_ref().unwrap().evaluate(batch)
1141        } else if projected.projection.len() < batch.num_columns() {
1142            // The case expressions do not use all the columns of the input batch.
1143            // Project first to reduce time spent filtering.
1144            let projected_batch = batch.project(&projected.projection)?;
1145            projected.body.expr_or_expr(&projected_batch, when_value)
1146        } else {
1147            // All columns are used in the case expressions, so there is no need to project.
1148            self.body.expr_or_expr(batch, when_value)
1149        }
1150    }
1151
1152    fn with_lookup_table(
1153        &self,
1154        batch: &RecordBatch,
1155        lookup_table: &LiteralLookupTable,
1156    ) -> Result<ColumnarValue> {
1157        let expr = self.body.expr.as_ref().unwrap();
1158        let evaluated_expression = expr.evaluate(batch)?;
1159
1160        let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_));
1161        let evaluated_expression = evaluated_expression.to_array(1)?;
1162
1163        let values = lookup_table.map_keys_to_values(&evaluated_expression)?;
1164
1165        let result = if is_scalar {
1166            ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?)
1167        } else {
1168            ColumnarValue::Array(values)
1169        };
1170
1171        Ok(result)
1172    }
1173}
1174
1175impl PhysicalExpr for CaseExpr {
    /// Returns this expression as `&dyn Any` so that callers can downcast it to a
    /// concrete `CaseExpr`.
    fn as_any(&self) -> &dyn Any {
        self
    }
1180
    /// Returns the data type this CASE expression produces for `input_schema`.
    ///
    /// The computation is delegated to the case body's `data_type`.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.body.data_type(input_schema)
    }
1184
1185    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1186        let nullable_then = self
1187            .body
1188            .when_then_expr
1189            .iter()
1190            .filter_map(|(w, t)| {
1191                let is_nullable = match t.nullable(input_schema) {
1192                    // Pass on error determining nullability verbatim
1193                    Err(e) => return Some(Err(e)),
1194                    Ok(n) => n,
1195                };
1196
1197                // Branches with a then expression that is not nullable do not impact the
1198                // nullability of the case expression.
1199                if !is_nullable {
1200                    return None;
1201                }
1202
1203                // For case-with-expression assume all 'then' expressions are reachable
1204                if self.body.expr.is_some() {
1205                    return Some(Ok(()));
1206                }
1207
1208                // For branches with a nullable 'then' expression, try to determine
1209                // if the 'then' expression is ever reachable in the situation where
1210                // it would evaluate to null.
1211
1212                // Replace the `then` expression with `NULL` in the `when` expression
1213                let with_null = match replace_with_null(w, t.as_ref(), input_schema) {
1214                    Err(e) => return Some(Err(e)),
1215                    Ok(e) => e,
1216                };
1217
1218                // Try to const evaluate the modified `when` expression.
1219                let predicate_result = match evaluate_predicate(&with_null) {
1220                    Err(e) => return Some(Err(e)),
1221                    Ok(b) => b,
1222                };
1223
1224                match predicate_result {
1225                    // Evaluation was inconclusive or true, so the 'then' expression is reachable
1226                    None | Some(true) => Some(Ok(())),
1227                    // Evaluation proves the branch will never be taken.
1228                    // The most common pattern for this is `WHEN x IS NOT NULL THEN x`.
1229                    Some(false) => None,
1230                }
1231            })
1232            .next();
1233
1234        if let Some(nullable_then) = nullable_then {
1235            // There is at least one reachable nullable 'then' expression, so the case
1236            // expression itself is nullable.
1237            // Use `Result::map` to propagate the error from `nullable_then` if there is one.
1238            nullable_then.map(|_| true)
1239        } else if let Some(e) = &self.body.else_expr {
1240            // There are no reachable nullable 'then' expressions, so all we still need to
1241            // check is the 'else' expression's nullability.
1242            e.nullable(input_schema)
1243        } else {
1244            // CASE produces NULL if there is no `else` expr
1245            // (aka when none of the `when_then_exprs` match)
1246            Ok(true)
1247        }
1248    }
1249
1250    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1251        match &self.eval_method {
1252            EvalMethod::WithExpression(p) => {
1253                // this use case evaluates "expr" and then compares the values with the "when"
1254                // values
1255                self.case_when_with_expr(batch, p)
1256            }
1257            EvalMethod::NoExpression(p) => {
1258                // The "when" conditions all evaluate to boolean in this use case and can be
1259                // arbitrary expressions
1260                self.case_when_no_expr(batch, p)
1261            }
1262            EvalMethod::InfallibleExprOrNull => {
1263                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
1264                self.case_column_or_null(batch)
1265            }
1266            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1267            EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1268            EvalMethod::WithExprScalarLookupTable(lookup_table) => {
1269                self.with_lookup_table(batch, lookup_table)
1270            }
1271        }
1272    }
1273
1274    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1275        let mut children = vec![];
1276        if let Some(expr) = &self.body.expr {
1277            children.push(expr)
1278        }
1279        self.body.when_then_expr.iter().for_each(|(cond, value)| {
1280            children.push(cond);
1281            children.push(value);
1282        });
1283
1284        if let Some(else_expr) = &self.body.else_expr {
1285            children.push(else_expr)
1286        }
1287        children
1288    }
1289
1290    // For physical CaseExpr, we do not allow modifying children size
1291    fn with_new_children(
1292        self: Arc<Self>,
1293        children: Vec<Arc<dyn PhysicalExpr>>,
1294    ) -> Result<Arc<dyn PhysicalExpr>> {
1295        if children.len() != self.children().len() {
1296            internal_err!("CaseExpr: Wrong number of children")
1297        } else {
1298            let (expr, when_then_expr, else_expr) =
1299                match (self.expr().is_some(), self.body.else_expr.is_some()) {
1300                    (true, true) => (
1301                        Some(&children[0]),
1302                        &children[1..children.len() - 1],
1303                        Some(&children[children.len() - 1]),
1304                    ),
1305                    (true, false) => {
1306                        (Some(&children[0]), &children[1..children.len()], None)
1307                    }
1308                    (false, true) => (
1309                        None,
1310                        &children[0..children.len() - 1],
1311                        Some(&children[children.len() - 1]),
1312                    ),
1313                    (false, false) => (None, &children[0..children.len()], None),
1314                };
1315            Ok(Arc::new(CaseExpr::try_new(
1316                expr.cloned(),
1317                when_then_expr.iter().cloned().tuples().collect(),
1318                else_expr.cloned(),
1319            )?))
1320        }
1321    }
1322
1323    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1324        write!(f, "CASE ")?;
1325        if let Some(e) = &self.body.expr {
1326            e.fmt_sql(f)?;
1327            write!(f, " ")?;
1328        }
1329
1330        for (w, t) in &self.body.when_then_expr {
1331            write!(f, "WHEN ")?;
1332            w.fmt_sql(f)?;
1333            write!(f, " THEN ")?;
1334            t.fmt_sql(f)?;
1335            write!(f, " ")?;
1336        }
1337
1338        if let Some(e) = &self.body.else_expr {
1339            write!(f, "ELSE ")?;
1340            e.fmt_sql(f)?;
1341            write!(f, " ")?;
1342        }
1343        write!(f, "END")
1344    }
1345}
1346
1347/// Attempts to const evaluate the given `predicate`.
1348/// Returns:
1349/// - `Some(true)` if the predicate evaluates to a truthy value.
1350/// - `Some(false)` if the predicate evaluates to a falsy value.
1351/// - `None` if the predicate could not be evaluated.
1352fn evaluate_predicate(predicate: &Arc<dyn PhysicalExpr>) -> Result<Option<bool>> {
1353    // Create a dummy record with no columns and one row
1354    let batch = RecordBatch::try_new_with_options(
1355        Arc::new(Schema::empty()),
1356        vec![],
1357        &RecordBatchOptions::new().with_row_count(Some(1)),
1358    )?;
1359
1360    // Evaluate the predicate and interpret the result as a boolean
1361    let result = match predicate.evaluate(&batch) {
1362        // An error during evaluation means we couldn't const evaluate the predicate, so return `None`
1363        Err(_) => None,
1364        Ok(ColumnarValue::Array(array)) => Some(
1365            ScalarValue::try_from_array(array.as_ref(), 0)?
1366                .cast_to(&DataType::Boolean)?,
1367        ),
1368        Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?),
1369    };
1370    Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true)))))
1371}
1372
1373fn replace_with_null(
1374    expr: &Arc<dyn PhysicalExpr>,
1375    expr_to_replace: &dyn PhysicalExpr,
1376    input_schema: &Schema,
1377) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
1378    let with_null = Arc::clone(expr)
1379        .transform_down(|e| {
1380            if e.as_ref().dyn_eq(expr_to_replace) {
1381                let data_type = e.data_type(input_schema)?;
1382                let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
1383                Ok(Transformed::yes(null_literal))
1384            } else {
1385                Ok(Transformed::no(e))
1386            }
1387        })?
1388        .data;
1389    Ok(with_null)
1390}
1391
1392/// Create a CASE expression
1393pub fn case(
1394    expr: Option<Arc<dyn PhysicalExpr>>,
1395    when_thens: Vec<WhenThen>,
1396    else_expr: Option<Arc<dyn PhysicalExpr>>,
1397) -> Result<Arc<dyn PhysicalExpr>> {
1398    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1399}
1400
1401#[cfg(test)]
1402mod tests {
1403    use super::*;
1404
1405    use crate::expressions;
1406    use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null, lit};
1407    use arrow::buffer::Buffer;
1408    use arrow::datatypes::DataType::Float64;
1409    use arrow::datatypes::Field;
1410    use datafusion_common::cast::{as_float64_array, as_int32_array};
1411    use datafusion_common::plan_err;
1412    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1413    use datafusion_expr::type_coercion::binary::comparison_coercion;
1414    use datafusion_expr_common::operator::Operator;
1415    use datafusion_physical_expr_common::physical_expr::fmt_sql;
1416    use half::f16;
1417
1418    #[test]
1419    fn case_with_expr() -> Result<()> {
1420        let batch = case_test_batch()?;
1421        let schema = batch.schema();
1422
1423        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1424        let when1 = lit("foo");
1425        let then1 = lit(123i32);
1426        let when2 = lit("bar");
1427        let then2 = lit(456i32);
1428
1429        let expr = generate_case_when_with_type_coercion(
1430            Some(col("a", &schema)?),
1431            vec![(when1, then1), (when2, then2)],
1432            None,
1433            schema.as_ref(),
1434        )?;
1435        let result = expr
1436            .evaluate(&batch)?
1437            .into_array(batch.num_rows())
1438            .expect("Failed to convert to array");
1439        let result = as_int32_array(&result)?;
1440
1441        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1442
1443        assert_eq!(expected, result);
1444
1445        Ok(())
1446    }
1447
// `CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END` where `a` is a
// Dictionary(UInt8, Utf8) column. The key that points at the NULL dictionary
// value and the unmatched 'baz' row both evaluate to NULL.
#[test]
fn case_with_expr_dictionary() -> Result<()> {
    let dict_values =
        StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
    let dict_keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
    let column = DictionaryArray::new(dict_keys, Arc::new(dict_values));

    let dict_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8));
    let fields = vec![Field::new("a", dict_type, true)];
    let batch =
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;
    let schema = batch.schema();

    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1486
// Literal WHEN values must keep working when the CASE input is a
// dictionary-encoded primitive column (Dictionary(UInt8, UInt64) here).
#[test]
fn case_with_expr_primitive_dictionary() -> Result<()> {
    let dict_values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]);
    let dict_keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
    let column = DictionaryArray::new(dict_keys, Arc::new(dict_values));

    let dict_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64));
    let fields = vec![Field::new("a", dict_type, true)];
    let batch =
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;
    let schema = batch.schema();

    // CASE a WHEN 10 THEN 123 WHEN 30 THEN 456 END
    let branches = vec![
        (lit(10_u64), lit(123_i32)),
        (lit(30_u64), lit(456_i32)),
    ];
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1526
// Literal WHEN values must keep working when the CASE input is a
// dictionary-encoded boolean column (Dictionary(UInt8, Boolean) here).
#[test]
fn case_with_expr_boolean_dictionary() -> Result<()> {
    let dict_values =
        BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
    let dict_keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
    let column = DictionaryArray::new(dict_keys, Arc::new(dict_values));

    let dict_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean));
    let fields = vec![Field::new("a", dict_type, true)];
    let batch =
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;
    let schema = batch.schema();

    // CASE a WHEN true THEN 123 WHEN false THEN 456 END
    let branches = vec![(lit(true), lit(123i32)), (lit(false), lit(456i32))];
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    // Only the row whose dictionary value is NULL produces NULL.
    let expected = Int32Array::from(vec![Some(123), Some(456), None, Some(123)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1566
// Every key points at the NULL slot of the dictionary, so no WHEN branch can
// match and the whole result is NULL.
#[test]
fn case_with_expr_all_null_dictionary() -> Result<()> {
    let dict_values =
        StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
    let dict_keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
    let column = DictionaryArray::new(dict_keys, Arc::new(dict_values));

    let dict_type =
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8));
    let fields = vec![Field::new("a", dict_type, true)];
    let batch =
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;
    let schema = batch.schema();

    // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![None, None, None, None]);
    assert_eq!(&expected, actual);

    Ok(())
}
1605
// `CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END`: rows that
// match no WHEN branch (including the NULL row) take the ELSE value.
#[test]
fn case_with_expr_else() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1637
// `CASE a WHEN 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END`:
// the row where `a` is 0 takes the NULL branch instead of the division.
#[test]
fn case_with_expr_divide_by_zero() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // ELSE branch: 25.0 / cast(a, float64)
    let a_as_f64 = cast(col("a", &schema)?, &schema, Float64)?;
    let quotient = binary(lit(25.0f64), Operator::Divide, a_as_f64, &schema)?;

    let branches = vec![(lit(0i32), lit(ScalarValue::Float64(None)))];
    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(quotient),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1672
// Searched form without ELSE:
// `CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END`.
#[test]
fn case_without_expr() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // Builds the predicate `a = <value>`.
    let a_equals = |value: &str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", &schema)?, Operator::Eq, lit(value), &schema)
    };

    let branches = vec![
        (a_equals("foo")?, lit(123i32)),
        (a_equals("bar")?, lit(456i32)),
    ];
    let case_expr =
        generate_case_when_with_type_coercion(None, branches, None, schema.as_ref())?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![Some(123), None, None, Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1712
// `CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END`: the NULL branch is
// never selected; non-null rows match `WHEN a`, and the NULL row takes ELSE.
#[test]
fn case_with_expr_when_null() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let null_branch = (lit(ScalarValue::Utf8(None)), lit(0i32));
    let self_branch = (col("a", &schema)?, lit(123i32));
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![null_branch, self_branch],
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1744
// `CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END`:
// rows where `a` is 0 or NULL take the ELSE branch and yield NULL.
#[test]
fn case_without_expr_divide_by_zero() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    let is_positive = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    let a_as_f64 = cast(col("a", &schema)?, &schema, Float64)?;
    let quotient = binary(lit(25.0f64), Operator::Divide, a_as_f64, &schema)?;
    let null_f64 = lit(ScalarValue::Float64(None));

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_positive, quotient)],
        Some(null_f64),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1779
1780    fn case_test_batch1() -> Result<RecordBatch> {
1781        let schema = Schema::new(vec![
1782            Field::new("a", DataType::Int32, true),
1783            Field::new("b", DataType::Int32, true),
1784            Field::new("c", DataType::Int32, true),
1785        ]);
1786        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1787        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1788        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1789        let batch = RecordBatch::try_new(
1790            Arc::new(schema),
1791            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1792        )?;
1793        Ok(batch)
1794    }
1795
// Searched form with ELSE:
// `CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END`.
#[test]
fn case_without_expr_else() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;
    let branches = vec![(is_foo, lit(123i32)), (is_bar, lit(456i32))];
    let fallback = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        None,
        branches,
        Some(fallback),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated)?;

    let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1837
// `CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END`: the Int32 ELSE literal is
// coerced so the whole expression evaluates to Float64.
#[test]
fn case_with_type_cast() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let float_then = lit(123.3f64);
    let int_else = lit(999i32);

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, float_then)],
        Some(int_else),
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    let expected =
        Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1873
// `SELECT CASE WHEN load4 = 1.77 THEN load4 END` over a column whose NULL
// slots still hold 1.77 in the value buffer; those slots must stay NULL.
#[test]
fn case_with_matches_and_nulls() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let load4_is_match =
        binary(col("load4", &schema)?, Operator::Eq, lit(1.77f64), &schema)?;
    let branches = vec![(load4_is_match, col("load4", &schema)?)];

    let case_expr =
        generate_case_when_with_type_coercion(None, branches, None, schema.as_ref())?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1908
// `SELECT CASE WHEN TRUE THEN load4 END`: a scalar predicate must pass the
// THEN column through unchanged, for a multi-row and a single-row batch.
#[test]
fn case_with_scalar_predicate() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let branches = vec![(lit(true), col("load4", &schema)?)];
    let expr =
        generate_case_when_with_type_coercion(None, branches, None, schema.as_ref())?;

    // Evaluates the CASE expression and materializes the result as an array.
    let evaluate = |input: &RecordBatch| -> Result<ArrayRef> {
        Ok(expr
            .evaluate(input)?
            .into_array(input.num_rows())
            .expect("Failed to convert to array"))
    };

    // many rows
    let many_rows = evaluate(&batch)?;
    let many_rows =
        as_float64_array(&many_rows).expect("failed to downcast to Float64Array");
    let expected_many =
        Float64Array::from(vec![Some(1.77), None, None, Some(1.78), None, Some(1.77)]);
    assert_eq!(&expected_many, many_rows);

    // one row
    let expected_single = Float64Array::from(vec![Some(1.1)]);
    let single_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(expected_single.clone())],
    )?;
    let single_row = evaluate(&single_batch)?;
    let single_row =
        as_float64_array(&single_row).expect("failed to downcast to Float64Array");
    assert_eq!(&expected_single, single_row);

    Ok(())
}
1955
// `SELECT CASE load4 WHEN 1.77 THEN load4 END`: the simple-CASE counterpart of
// the masked-NULL check; NULL slots holding 1.77 must not match.
#[test]
fn case_expr_matches_and_nulls() -> Result<()> {
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let subject = col("load4", &schema)?;
    let branches = vec![(lit(1.77f64), col("load4", &schema)?)];

    let case_expr = generate_case_when_with_type_coercion(
        Some(subject),
        branches,
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);

    Ok(())
}
1986
// `SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END`: the predicate
// is never true, so every output row must be NULL.
#[test]
fn test_when_null_and_some_cond_else_null() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let null_bool: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(None)));
    let a_is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let predicate = binary(null_bool, Operator::And, a_is_foo, &schema)?;
    let branches = vec![(predicate, col("a", &schema)?)];

    let case_expr = Arc::new(CaseExpr::try_new(None, branches, None)?);
    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let strings = as_string_array(&evaluated);

    // all result values should be null
    assert_eq!(strings.logical_null_count(), batch.num_rows());
    Ok(())
}
2012
2013    fn case_test_batch() -> Result<RecordBatch> {
2014        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2015        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
2016        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
2017        Ok(batch)
2018    }
2019
2020    // Construct an array that has several NULL values whose
2021    // underlying buffer actually matches the where expr predicate
2022    fn case_test_batch_nulls() -> Result<RecordBatch> {
2023        let load4: Float64Array = vec![
2024            Some(1.77), // 1.77
2025            Some(1.77), // null <-- same value, but will be set to null
2026            Some(1.77), // null <-- same value, but will be set to null
2027            Some(1.78), // 1.78
2028            None,       // null
2029            Some(1.77), // 1.77
2030        ]
2031        .into_iter()
2032        .collect();
2033
2034        let null_buffer = Buffer::from([0b00101001u8]);
2035        let load4 = load4
2036            .into_data()
2037            .into_builder()
2038            .null_bit_buffer(Some(null_buffer))
2039            .build()
2040            .unwrap();
2041        let load4: Float64Array = load4.into();
2042
2043        let batch =
2044            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
2045        Ok(batch)
2046    }
2047
// Checks THEN/ELSE type coercion: branch types with no common type are
// rejected, while mixed numeric branches are widened to a common type.
#[test]
fn case_test_incompatible() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();

    // Builds the predicate `a = <value>`.
    let a_equals = |value: &str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", &schema)?, Operator::Eq, lit(value), &schema)
    };

    // then 1 is int32
    // then 2 is boolean
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
    let incompatible = generate_case_when_with_type_coercion(
        None,
        vec![
            (a_equals("foo")?, lit(123i32)),
            (a_equals("bar")?, lit(true)),
        ],
        None,
        schema.as_ref(),
    );
    assert!(incompatible.is_err());

    // then 1 is int32
    // then 2 is int64
    // else is float
    // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
    //
    // `?` is used rather than `assert!(is_ok())` + `unwrap()` so that a failure
    // reports the underlying coercion error.
    let widened = generate_case_when_with_type_coercion(
        None,
        vec![
            (a_equals("foo")?, lit(123i32)),
            (a_equals("bar")?, lit(456i64)),
        ],
        Some(lit(1.23f64)),
        schema.as_ref(),
    )?;
    let result_type = widened.data_type(schema.as_ref())?;
    assert_eq!(Float64, result_type);
    Ok(())
}
2110
// Equality between CASE expressions: identical inputs compare equal, while a
// missing ELSE or a missing WHEN/THEN branch makes them differ.
#[test]
fn case_eq() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let foo_branch = (lit("foo"), lit(123i32));
    let bar_branch = (lit("bar"), lit(456i32));
    let fallback = lit(999i32);

    // Builds `CASE a <branches> [ELSE <else_expr>] END` with type coercion.
    let build = |branches: Vec<WhenThen>,
                 else_expr: Option<Arc<dyn PhysicalExpr>>|
     -> Result<Arc<dyn PhysicalExpr>> {
        generate_case_when_with_type_coercion(
            Some(col("a", &schema)?),
            branches,
            else_expr,
            &schema,
        )
    };

    // Two branches plus ELSE, built twice from the same parts.
    let expr1 = build(
        vec![foo_branch.clone(), bar_branch.clone()],
        Some(Arc::clone(&fallback)),
    )?;
    let expr2 = build(
        vec![foo_branch.clone(), bar_branch.clone()],
        Some(Arc::clone(&fallback)),
    )?;
    // Same branches, no ELSE.
    let expr3 = build(vec![foo_branch.clone(), bar_branch], None)?;
    // Only the first branch, with ELSE.
    let expr4 = build(vec![foo_branch], Some(fallback))?;

    assert!(expr1.eq(&expr2));
    assert!(expr2.eq(&expr1));

    assert!(expr2.ne(&expr3));
    assert!(expr3.ne(&expr2));

    assert!(expr1.ne(&expr4));
    assert!(expr4.ne(&expr1));

    Ok(())
}
2166
// Rewriting the Utf8 literals of a CASE expression through `transform` and
// `transform_down` must produce equal trees that differ from the original.
#[test]
fn case_transform() -> Result<()> {
    // Replaces a non-null Utf8 literal with its upper-cased form; every other
    // node is returned untouched.
    fn uppercase_utf8_literal(
        node: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        let replacement = node
            .as_any()
            .downcast_ref::<Literal>()
            .and_then(|literal| match literal.value() {
                ScalarValue::Utf8(Some(text)) => Some(lit(text.to_uppercase())),
                _ => None,
            });
        Ok(match replacement {
            Some(upper) => Transformed::yes(upper),
            None => Transformed::no(node),
        })
    }

    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let branches = vec![(lit("foo"), lit(123i32)), (lit("bar"), lit(456i32))];
    let expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        branches,
        Some(lit(999i32)),
        &schema,
    )?;

    let expr2 = Arc::clone(&expr)
        .transform(uppercase_utf8_literal)
        .data()
        .unwrap();
    let expr3 = Arc::clone(&expr)
        .transform_down(uppercase_utf8_literal)
        .data()
        .unwrap();

    assert!(expr.ne(&expr2));
    assert!(expr2.eq(&expr3));

    Ok(())
}
2232
// `CASE WHEN c1 <= 250 THEN c2 END` must select the
// `EvalMethod::InfallibleExprOrNull` strategy and still yield one output row
// per input row, with NULL wherever the predicate fails or `c2` is NULL.
#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // c1 = 0..1000; c2 = "string {i}", with NULL at every multiple of 7
    let c1 = Int32Array::from_iter_values(0..1000);
    let c2 = StringArray::from_iter(
        (0..1000).map(|i| (i % 7 != 0).then(|| format!("string {i}"))),
    );
    let fields = vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ];
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(c1), Arc::new(c2)],
    )
    .unwrap();

    let predicate = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
    assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));

    let ColumnarValue::Array(array) = expr.evaluate(&batch)? else {
        unreachable!()
    };
    assert_eq!(1000, array.len());
    // 251 rows satisfy the predicate and 36 of those have a NULL c2,
    // leaving 215 non-null outputs.
    assert_eq!(785, array.null_count());
    Ok(())
}
2271
// `CASE WHEN a <= 2 THEN b ELSE c END` must select the
// `EvalMethod::ExpressionOrExpression` strategy; rows where the predicate is
// NULL take the ELSE column.
#[test]
fn test_expr_or_expr_specialization() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    let a_at_most_two =
        binary(col("a", &schema)?, Operator::LtEq, lit(2i32), &schema)?;
    let branches = vec![(a_at_most_two, col("b", &schema)?)];
    let fallback = col("c", &schema)?;

    let expr = CaseExpr::try_new(None, branches, Some(fallback))?;
    assert!(matches!(
        expr.eval_method,
        EvalMethod::ExpressionOrExpression(_)
    ));

    let evaluated = expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated).expect("failed to downcast to Int32Array");

    let expected = Int32Array::from(vec![Some(3), None, Some(777), None]);
    assert_eq!(&expected, actual);
    Ok(())
}
2300
2301    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2302        Arc::new(Column::new(name, index))
2303    }
2304
2305    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2306        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2307    }
2308
2309    fn generate_case_when_with_type_coercion(
2310        expr: Option<Arc<dyn PhysicalExpr>>,
2311        when_thens: Vec<WhenThen>,
2312        else_expr: Option<Arc<dyn PhysicalExpr>>,
2313        input_schema: &Schema,
2314    ) -> Result<Arc<dyn PhysicalExpr>> {
2315        let coerce_type =
2316            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2317        let (when_thens, else_expr) = match coerce_type {
2318            None => plan_err!(
2319                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2320            ),
2321            Some(data_type) => {
2322                // cast then expr
2323                let left = when_thens
2324                    .into_iter()
2325                    .map(|(when, then)| {
2326                        let then = try_cast(then, input_schema, data_type.clone())?;
2327                        Ok((when, then))
2328                    })
2329                    .collect::<Result<Vec<_>>>()?;
2330                let right = match else_expr {
2331                    None => None,
2332                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2333                };
2334
2335                Ok((left, right))
2336            }
2337        }?;
2338        case(expr, when_thens, else_expr)
2339    }
2340
2341    fn get_case_common_type(
2342        when_thens: &[WhenThen],
2343        else_expr: Option<Arc<dyn PhysicalExpr>>,
2344        input_schema: &Schema,
2345    ) -> Option<DataType> {
2346        let thens_type = when_thens
2347            .iter()
2348            .map(|when_then| {
2349                let data_type = &when_then.1.data_type(input_schema).unwrap();
2350                data_type.clone()
2351            })
2352            .collect::<Vec<_>>();
2353        let else_type = match else_expr {
2354            None => {
2355                // case when then exprs must have one then value
2356                thens_type[0].clone()
2357            }
2358            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2359        };
2360        thens_type
2361            .iter()
2362            .try_fold(else_type, |left_type, right_type| {
2363                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
2364                // refactor again.
2365                comparison_coercion(&left_type, right_type)
2366            })
2367    }
2368
2369    #[test]
2370    fn test_fmt_sql() -> Result<()> {
2371        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2372
2373        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
2374        let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
2375        let then = lit(123.3f64);
2376        let else_value = lit(999i32);
2377
2378        let expr = generate_case_when_with_type_coercion(
2379            None,
2380            vec![(when, then)],
2381            Some(else_value),
2382            &schema,
2383        )?;
2384
2385        let display_string = expr.to_string();
2386        assert_eq!(
2387            display_string,
2388            "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2389        );
2390
2391        let sql_string = fmt_sql(expr.as_ref()).to_string();
2392        assert_eq!(
2393            sql_string,
2394            "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2395        );
2396
2397        Ok(())
2398    }
2399
2400    fn when_then_else(
2401        when: &Arc<dyn PhysicalExpr>,
2402        then: &Arc<dyn PhysicalExpr>,
2403        els: &Arc<dyn PhysicalExpr>,
2404    ) -> Result<Arc<dyn PhysicalExpr>> {
2405        let case = CaseExpr::try_new(
2406            None,
2407            vec![(Arc::clone(when), Arc::clone(then))],
2408            Some(Arc::clone(els)),
2409        )?;
2410        Ok(Arc::new(case))
2411    }
2412
2413    #[test]
2414    fn test_case_expression_nullability_with_nullable_column() -> Result<()> {
2415        case_expression_nullability(true)
2416    }
2417
2418    #[test]
2419    fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> {
2420        case_expression_nullability(false)
2421    }
2422
2423    fn case_expression_nullability(col_is_nullable: bool) -> Result<()> {
2424        let schema =
2425            Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]);
2426
2427        let foo = col("foo", &schema)?;
2428        let foo_is_not_null = is_not_null(Arc::clone(&foo))?;
2429        let foo_is_null = expressions::is_null(Arc::clone(&foo))?;
2430        let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?;
2431        let zero = lit(0);
2432        let foo_eq_zero =
2433            binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?;
2434
2435        assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema);
2436        assert_not_nullable(when_then_else(&not_foo_is_null, &foo, &zero)?, &schema);
2437        assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema);
2438
2439        assert_not_nullable(
2440            when_then_else(
2441                &binary(
2442                    Arc::clone(&foo_is_not_null),
2443                    Operator::And,
2444                    Arc::clone(&foo_eq_zero),
2445                    &schema,
2446                )?,
2447                &foo,
2448                &zero,
2449            )?,
2450            &schema,
2451        );
2452
2453        assert_not_nullable(
2454            when_then_else(
2455                &binary(
2456                    Arc::clone(&foo_eq_zero),
2457                    Operator::And,
2458                    Arc::clone(&foo_is_not_null),
2459                    &schema,
2460                )?,
2461                &foo,
2462                &zero,
2463            )?,
2464            &schema,
2465        );
2466
2467        assert_not_nullable(
2468            when_then_else(
2469                &binary(
2470                    Arc::clone(&foo_is_not_null),
2471                    Operator::Or,
2472                    Arc::clone(&foo_eq_zero),
2473                    &schema,
2474                )?,
2475                &foo,
2476                &zero,
2477            )?,
2478            &schema,
2479        );
2480
2481        assert_not_nullable(
2482            when_then_else(
2483                &binary(
2484                    Arc::clone(&foo_eq_zero),
2485                    Operator::Or,
2486                    Arc::clone(&foo_is_not_null),
2487                    &schema,
2488                )?,
2489                &foo,
2490                &zero,
2491            )?,
2492            &schema,
2493        );
2494
2495        assert_nullability(
2496            when_then_else(
2497                &binary(
2498                    Arc::clone(&foo_is_null),
2499                    Operator::Or,
2500                    Arc::clone(&foo_eq_zero),
2501                    &schema,
2502                )?,
2503                &foo,
2504                &zero,
2505            )?,
2506            &schema,
2507            col_is_nullable,
2508        );
2509
2510        assert_nullability(
2511            when_then_else(
2512                &binary(
2513                    binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?,
2514                    Operator::Or,
2515                    Arc::clone(&foo_is_null),
2516                    &schema,
2517                )?,
2518                &foo,
2519                &zero,
2520            )?,
2521            &schema,
2522            col_is_nullable,
2523        );
2524
2525        assert_not_nullable(
2526            when_then_else(
2527                &binary(
2528                    binary(
2529                        binary(
2530                            Arc::clone(&foo),
2531                            Operator::Eq,
2532                            Arc::clone(&zero),
2533                            &schema,
2534                        )?,
2535                        Operator::And,
2536                        Arc::clone(&foo_is_not_null),
2537                        &schema,
2538                    )?,
2539                    Operator::Or,
2540                    binary(
2541                        binary(
2542                            Arc::clone(&foo),
2543                            Operator::Eq,
2544                            Arc::clone(&foo),
2545                            &schema,
2546                        )?,
2547                        Operator::And,
2548                        Arc::clone(&foo_is_not_null),
2549                        &schema,
2550                    )?,
2551                    &schema,
2552                )?,
2553                &foo,
2554                &zero,
2555            )?,
2556            &schema,
2557        );
2558
2559        Ok(())
2560    }
2561
2562    fn assert_not_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2563        assert!(!expr.nullable(schema).unwrap());
2564    }
2565
2566    fn assert_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2567        assert!(expr.nullable(schema).unwrap());
2568    }
2569
2570    fn assert_nullability(expr: Arc<dyn PhysicalExpr>, schema: &Schema, nullable: bool) {
2571        if nullable {
2572            assert_nullable(expr, schema);
2573        } else {
2574            assert_not_nullable(expr, schema);
2575        }
2576    }
2577
2578    // Test Lookup evaluation
2579
2580    fn test_case_when_literal_lookup(
2581        values: ArrayRef,
2582        lookup_map: &[(ScalarValue, ScalarValue)],
2583        else_value: Option<ScalarValue>,
2584        expected: ArrayRef,
2585    ) {
2586        // Create lookup
2587        // CASE <expr>
2588        // WHEN <when_constant_1> THEN <then_constant_1>
2589        // WHEN <when_constant_2> THEN <then_constant_2>
2590        // [ ELSE <else_constant> ]
2591
2592        let schema = Schema::new(vec![Field::new(
2593            "a",
2594            values.data_type().clone(),
2595            values.is_nullable(),
2596        )]);
2597        let schema = Arc::new(schema);
2598
2599        let batch = RecordBatch::try_new(schema, vec![values])
2600            .expect("failed to create RecordBatch");
2601
2602        let schema = batch.schema_ref();
2603        let case = col("a", schema).expect("failed to create col");
2604
2605        let when_then = lookup_map
2606            .iter()
2607            .map(|(when, then)| {
2608                (
2609                    Arc::new(Literal::new(when.clone())) as _,
2610                    Arc::new(Literal::new(then.clone())) as _,
2611                )
2612            })
2613            .collect::<Vec<WhenThen>>();
2614
2615        let else_expr = else_value.map(|else_value| {
2616            Arc::new(Literal::new(else_value)) as Arc<dyn PhysicalExpr>
2617        });
2618        let expr = CaseExpr::try_new(Some(case), when_then, else_expr)
2619            .expect("failed to create case");
2620
2621        // Assert that we are testing what we intend to assert
2622        assert!(
2623            matches!(
2624                expr.eval_method,
2625                EvalMethod::WithExprScalarLookupTable { .. }
2626            ),
2627            "we should use the expected eval method"
2628        );
2629
2630        let actual = expr
2631            .evaluate(&batch)
2632            .expect("failed to evaluate case")
2633            .into_array(batch.num_rows())
2634            .expect("Failed to convert to array");
2635
2636        assert_eq!(
2637            actual.data_type(),
2638            expected.data_type(),
2639            "Data type mismatch"
2640        );
2641
2642        assert_eq!(
2643            actual.as_ref(),
2644            expected.as_ref(),
2645            "actual (left) does not match expected (right)"
2646        );
2647    }
2648
2649    fn create_lookup<When, Then>(
2650        when_then_pairs: impl IntoIterator<Item = (When, Then)>,
2651    ) -> Vec<(ScalarValue, ScalarValue)>
2652    where
2653        ScalarValue: From<When>,
2654        ScalarValue: From<Then>,
2655    {
2656        when_then_pairs
2657            .into_iter()
2658            .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then)))
2659            .collect()
2660    }
2661
2662    fn create_input_and_expected<Input, Expected, InputFromItem, ExpectedFromItem>(
2663        input_and_expected_pairs: impl IntoIterator<Item = (InputFromItem, ExpectedFromItem)>,
2664    ) -> (Input, Expected)
2665    where
2666        Input: Array + From<Vec<InputFromItem>>,
2667        Expected: Array + From<Vec<ExpectedFromItem>>,
2668    {
2669        let (input_items, expected_items): (Vec<InputFromItem>, Vec<ExpectedFromItem>) =
2670            input_and_expected_pairs.into_iter().unzip();
2671
2672        (Input::from(input_items), Expected::from(expected_items))
2673    }
2674
2675    fn test_lookup_eval_with_and_without_else(
2676        lookup_map: &[(ScalarValue, ScalarValue)],
2677        input_values: ArrayRef,
2678        expected: StringArray,
2679    ) {
2680        // Testing without ELSE should fallback to None
2681        test_case_when_literal_lookup(
2682            Arc::clone(&input_values),
2683            lookup_map,
2684            None,
2685            Arc::new(expected.clone()),
2686        );
2687
2688        // Testing with Else
2689        let else_value = "___fallback___";
2690
2691        // Changing each expected None to be fallback
2692        let expected_with_else = expected
2693            .iter()
2694            .map(|item| item.unwrap_or(else_value))
2695            .map(Some)
2696            .collect::<StringArray>();
2697
2698        // Test case
2699        test_case_when_literal_lookup(
2700            input_values,
2701            lookup_map,
2702            Some(ScalarValue::Utf8(Some(else_value.to_string()))),
2703            Arc::new(expected_with_else),
2704        );
2705    }
2706
2707    #[test]
2708    fn test_case_when_literal_lookup_int32_to_string() {
2709        let lookup_map = create_lookup([
2710            (Some(4), Some("four")),
2711            (Some(2), Some("two")),
2712            (Some(3), Some("three")),
2713            (Some(1), Some("one")),
2714        ]);
2715
2716        let (input_values, expected) =
2717            create_input_and_expected::<Int32Array, StringArray, _, _>([
2718                (1, Some("one")),
2719                (2, Some("two")),
2720                (3, Some("three")),
2721                (3, Some("three")),
2722                (2, Some("two")),
2723                (3, Some("three")),
2724                (5, None), // No match in WHEN
2725                (5, None), // No match in WHEN
2726                (3, Some("three")),
2727                (5, None), // No match in WHEN
2728            ]);
2729
2730        test_lookup_eval_with_and_without_else(
2731            &lookup_map,
2732            Arc::new(input_values),
2733            expected,
2734        );
2735    }
2736
2737    #[test]
2738    fn test_case_when_literal_lookup_none_case_should_never_match() {
2739        let lookup_map = create_lookup([
2740            (Some(4), Some("four")),
2741            (None, Some("none")),
2742            (Some(2), Some("two")),
2743            (Some(1), Some("one")),
2744        ]);
2745
2746        let (input_values, expected) =
2747            create_input_and_expected::<Int32Array, StringArray, _, _>([
2748                (Some(1), Some("one")),
2749                (Some(5), None), // No match in WHEN
2750                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
2751                (Some(2), Some("two")),
2752                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
2753                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
2754                (Some(2), Some("two")),
2755                (Some(5), None), // No match in WHEN
2756            ]);
2757
2758        test_lookup_eval_with_and_without_else(
2759            &lookup_map,
2760            Arc::new(input_values),
2761            expected,
2762        );
2763    }
2764
2765    #[test]
2766    fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() {
2767        let lookup_map = create_lookup([
2768            (Some(4), Some("four")),
2769            (Some(4), Some("no 4")),
2770            (Some(2), Some("two")),
2771            (Some(2), Some("no 2")),
2772            (Some(3), Some("three")),
2773            (Some(3), Some("no 3")),
2774            (Some(2), Some("no 2")),
2775            (Some(4), Some("no 4")),
2776            (Some(2), Some("no 2")),
2777            (Some(3), Some("no 3")),
2778            (Some(4), Some("no 4")),
2779            (Some(2), Some("no 2")),
2780            (Some(3), Some("no 3")),
2781            (Some(3), Some("no 3")),
2782        ]);
2783
2784        let (input_values, expected) =
2785            create_input_and_expected::<Int32Array, StringArray, _, _>([
2786                (1, None), // No match in WHEN
2787                (2, Some("two")),
2788                (3, Some("three")),
2789                (3, Some("three")),
2790                (2, Some("two")),
2791                (3, Some("three")),
2792                (5, None), // No match in WHEN
2793                (5, None), // No match in WHEN
2794                (3, Some("three")),
2795                (5, None), // No match in WHEN
2796            ]);
2797
2798        test_lookup_eval_with_and_without_else(
2799            &lookup_map,
2800            Arc::new(input_values),
2801            expected,
2802        );
2803    }
2804
2805    #[test]
2806    fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases()
2807     {
2808        let lookup_map = create_lookup([
2809            (Some(4.0), Some("four point zero")),
2810            (Some(f32::NAN), Some("NaN")),
2811            (Some(3.2), Some("three point two")),
2812            // Duplicate case to make sure it is not used
2813            (Some(f32::NAN), Some("should not use this NaN branch")),
2814            (Some(f32::INFINITY), Some("Infinity")),
2815            (Some(0.0), Some("zero")),
2816            // Duplicate case to make sure it is not used
2817            (
2818                Some(f32::INFINITY),
2819                Some("should not use this Infinity branch"),
2820            ),
2821            (Some(1.1), Some("one point one")),
2822        ]);
2823
2824        let (input_values, expected) =
2825            create_input_and_expected::<Float32Array, StringArray, _, _>([
2826                (1.1, Some("one point one")),
2827                (f32::NAN, Some("NaN")),
2828                (3.2, Some("three point two")),
2829                (3.2, Some("three point two")),
2830                (0.0, Some("zero")),
2831                (f32::INFINITY, Some("Infinity")),
2832                (3.2, Some("three point two")),
2833                (f32::NEG_INFINITY, None), // No match in WHEN
2834                (f32::NEG_INFINITY, None), // No match in WHEN
2835                (3.2, Some("three point two")),
2836                (-0.0, None), // No match in WHEN
2837            ]);
2838
2839        test_lookup_eval_with_and_without_else(
2840            &lookup_map,
2841            Arc::new(input_values),
2842            expected,
2843        );
2844    }
2845
2846    #[test]
2847    fn test_case_when_literal_lookup_f16_to_string_with_special_values() {
2848        let lookup_map = create_lookup([
2849            (
2850                ScalarValue::Float16(Some(f16::from_f32(3.2))),
2851                Some("3 dot 2"),
2852            ),
2853            (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")),
2854            (
2855                ScalarValue::Float16(Some(f16::from_f32(17.4))),
2856                Some("17 dot 4"),
2857            ),
2858            (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")),
2859            (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")),
2860        ]);
2861
2862        let (input_values, expected) =
2863            create_input_and_expected::<Float16Array, StringArray, _, _>([
2864                (f16::from_f32(3.2), Some("3 dot 2")),
2865                (f16::NAN, Some("NaN")),
2866                (f16::from_f32(17.4), Some("17 dot 4")),
2867                (f16::from_f32(17.4), Some("17 dot 4")),
2868                (f16::INFINITY, Some("Infinity")),
2869                (f16::from_f32(17.4), Some("17 dot 4")),
2870                (f16::NEG_INFINITY, None), // No match in WHEN
2871                (f16::NEG_INFINITY, None), // No match in WHEN
2872                (f16::from_f32(17.4), Some("17 dot 4")),
2873                (f16::NEG_ZERO, None), // No match in WHEN
2874            ]);
2875
2876        test_lookup_eval_with_and_without_else(
2877            &lookup_map,
2878            Arc::new(input_values),
2879            expected,
2880        );
2881    }
2882
2883    #[test]
2884    fn test_case_when_literal_lookup_f32_to_string_with_special_values() {
2885        let lookup_map = create_lookup([
2886            (3.2, Some("3 dot 2")),
2887            (f32::NAN, Some("NaN")),
2888            (17.4, Some("17 dot 4")),
2889            (f32::INFINITY, Some("Infinity")),
2890            (f32::ZERO, Some("zero")),
2891        ]);
2892
2893        let (input_values, expected) =
2894            create_input_and_expected::<Float32Array, StringArray, _, _>([
2895                (3.2, Some("3 dot 2")),
2896                (f32::NAN, Some("NaN")),
2897                (17.4, Some("17 dot 4")),
2898                (17.4, Some("17 dot 4")),
2899                (f32::INFINITY, Some("Infinity")),
2900                (17.4, Some("17 dot 4")),
2901                (f32::NEG_INFINITY, None), // No match in WHEN
2902                (f32::NEG_INFINITY, None), // No match in WHEN
2903                (17.4, Some("17 dot 4")),
2904                (-0.0, None), // No match in WHEN
2905            ]);
2906
2907        test_lookup_eval_with_and_without_else(
2908            &lookup_map,
2909            Arc::new(input_values),
2910            expected,
2911        );
2912    }
2913
2914    #[test]
2915    fn test_case_when_literal_lookup_f64_to_string_with_special_values() {
2916        let lookup_map = create_lookup([
2917            (3.2, Some("3 dot 2")),
2918            (f64::NAN, Some("NaN")),
2919            (17.4, Some("17 dot 4")),
2920            (f64::INFINITY, Some("Infinity")),
2921            (f64::ZERO, Some("zero")),
2922        ]);
2923
2924        let (input_values, expected) =
2925            create_input_and_expected::<Float64Array, StringArray, _, _>([
2926                (3.2, Some("3 dot 2")),
2927                (f64::NAN, Some("NaN")),
2928                (17.4, Some("17 dot 4")),
2929                (17.4, Some("17 dot 4")),
2930                (f64::INFINITY, Some("Infinity")),
2931                (17.4, Some("17 dot 4")),
2932                (f64::NEG_INFINITY, None), // No match in WHEN
2933                (f64::NEG_INFINITY, None), // No match in WHEN
2934                (17.4, Some("17 dot 4")),
2935                (-0.0, None), // No match in WHEN
2936            ]);
2937
2938        test_lookup_eval_with_and_without_else(
2939            &lookup_map,
2940            Arc::new(input_values),
2941            expected,
2942        );
2943    }
2944
2945    // Test that we don't lose the decimal precision and scale info
2946    #[test]
2947    fn test_decimal_with_non_default_precision_and_scale() {
2948        let lookup_map = create_lookup([
2949            (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")),
2950            (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")),
2951            (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")),
2952            (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")),
2953        ]);
2954
2955        let (input_values, expected) =
2956            create_input_and_expected::<Decimal32Array, StringArray, _, _>([
2957                (1, Some("one")),
2958                (2, Some("two")),
2959                (3, Some("three")),
2960                (3, Some("three")),
2961                (2, Some("two")),
2962                (3, Some("three")),
2963                (5, None), // No match in WHEN
2964                (5, None), // No match in WHEN
2965                (3, Some("three")),
2966                (5, None), // No match in WHEN
2967            ]);
2968
2969        let input_values = input_values
2970            .with_precision_and_scale(3, 2)
2971            .expect("must be able to set precision and scale");
2972
2973        test_lookup_eval_with_and_without_else(
2974            &lookup_map,
2975            Arc::new(input_values),
2976            expected,
2977        );
2978    }
2979
2980    // Test that we don't lose the timezone info
2981    #[test]
2982    fn test_timestamp_with_non_default_timezone() {
2983        let timezone: Option<Arc<str>> = Some("-10:00".into());
2984        let lookup_map = create_lookup([
2985            (
2986                ScalarValue::TimestampMillisecond(Some(4), timezone.clone()),
2987                Some("four"),
2988            ),
2989            (
2990                ScalarValue::TimestampMillisecond(Some(2), timezone.clone()),
2991                Some("two"),
2992            ),
2993            (
2994                ScalarValue::TimestampMillisecond(Some(3), timezone.clone()),
2995                Some("three"),
2996            ),
2997            (
2998                ScalarValue::TimestampMillisecond(Some(1), timezone.clone()),
2999                Some("one"),
3000            ),
3001        ]);
3002
3003        let (input_values, expected) =
3004            create_input_and_expected::<TimestampMillisecondArray, StringArray, _, _>([
3005                (1, Some("one")),
3006                (2, Some("two")),
3007                (3, Some("three")),
3008                (3, Some("three")),
3009                (2, Some("two")),
3010                (3, Some("three")),
3011                (5, None), // No match in WHEN
3012                (5, None), // No match in WHEN
3013                (3, Some("three")),
3014                (5, None), // No match in WHEN
3015            ]);
3016
3017        let input_values = input_values.with_timezone_opt(timezone);
3018
3019        test_lookup_eval_with_and_without_else(
3020            &lookup_map,
3021            Arc::new(input_values),
3022            expected,
3023        );
3024    }
3025
3026    #[test]
3027    fn test_with_strings_to_int32() {
3028        let lookup_map = create_lookup([
3029            (Some("why"), Some(42)),
3030            (Some("what"), Some(22)),
3031            (Some("when"), Some(17)),
3032        ]);
3033
3034        let (input_values, expected) =
3035            create_input_and_expected::<StringArray, Int32Array, _, _>([
3036                (Some("why"), Some(42)),
3037                (Some("5"), None), // No match in WHEN
3038                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
3039                (Some("what"), Some(22)),
3040                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
3041                (None, None), // None cases are never match in CASE <expr> WHEN <value> syntax
3042                (Some("what"), Some(22)),
3043                (Some("5"), None), // No match in WHEN
3044            ]);
3045
3046        let input_values = Arc::new(input_values) as ArrayRef;
3047
3048        // Testing without ELSE should fallback to None
3049        test_case_when_literal_lookup(
3050            Arc::clone(&input_values),
3051            &lookup_map,
3052            None,
3053            Arc::new(expected.clone()),
3054        );
3055
3056        // Testing with Else
3057        let else_value = 101;
3058
3059        // Changing each expected None to be fallback
3060        let expected_with_else = expected
3061            .iter()
3062            .map(|item| item.unwrap_or(else_value))
3063            .map(Some)
3064            .collect::<Int32Array>();
3065
3066        // Test case
3067        test_case_when_literal_lookup(
3068            input_values,
3069            &lookup_map,
3070            Some(ScalarValue::Int32(Some(else_value))),
3071            Arc::new(expected_with_else),
3072        );
3073    }
3074}