Skip to main content

datafusion_physical_expr/expressions/
case.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod literal_lookup_table;
19
20use super::{Column, Literal};
21use crate::PhysicalExpr;
22use crate::expressions::{lit, try_cast};
23use arrow::array::*;
24use arrow::compute::kernels::zip::zip;
25use arrow::compute::{
26    FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::{
32    DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err,
33    internal_datafusion_err, internal_err,
34};
35use datafusion_expr::ColumnarValue;
36use indexmap::{IndexMap, IndexSet};
37use std::borrow::Cow;
38use std::hash::Hash;
39use std::{any::Any, sync::Arc};
40
41use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
42use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
43use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
44use datafusion_physical_expr_common::datum::compare_with_eq;
45use datafusion_physical_expr_common::utils::scatter;
46use itertools::Itertools;
47use std::fmt::{Debug, Formatter};
48
/// A single `WHEN ... THEN ...` branch: the first element is the "when" expression,
/// the second element is the "then" expression.
pub(super) type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
50
/// The strategy used to evaluate a [`CaseExpr`].
///
/// The variant is selected once, when the expression is constructed, by
/// `CaseExpr::find_best_eval_method`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General evaluation of the form without a base expression:
    ///
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    NoExpression(ProjectedCaseBody),
    /// General evaluation of the form with a base expression:
    ///
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    WithExpression(ProjectedCaseBody),
    /// This is a specialization for a specific use case where we can take a fast path
    /// for expressions that are infallible and can be cheaply computed for the entire
    /// record batch rather than just for the rows where the predicate is true.
    ///
    /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END
    InfallibleExprOrNull,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair and both the `then` and `else` expressions
    /// are literal values.
    ///
    /// CASE WHEN condition THEN literal ELSE literal END
    ScalarOrScalar,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair, the `then` is an expression, and `else` is either
    /// an expression, literal NULL or absent.
    ///
    /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible
    /// `then` expressions.
    ///
    /// CASE WHEN condition THEN expression [ELSE expression] END
    ExpressionOrExpression(ProjectedCaseBody),

    /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals.
    ///
    /// See [`LiteralLookupTable`] for more details.
    WithExprScalarLookupTable(LiteralLookupTable),
}
90
/// Implements [`Hash`] so that `derive(Hash)` can be used on [`EvalMethod`].
///
/// A real hash is not implemented because [`Hash`] is not dyn compatible, so it cannot be
/// implemented for `dyn` [`literal_lookup_table::WhenLiteralIndexMap`].
///
/// Contributing nothing to the hasher is still valid: the lookup table is derived from
/// `PhysicalExpr`s, and those are already hashed.
impl Hash for LiteralLookupTable {
    // Intentionally writes nothing into `_state`.
    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
}
100
/// Implements [`PartialEq`] so that `derive(PartialEq)` can be used on [`EvalMethod`].
///
/// A real comparison is not implemented because [`PartialEq`] is not dyn compatible, so it
/// cannot be implemented for `dyn` [`literal_lookup_table::WhenLiteralIndexMap`].
///
/// Every pair of lookup tables therefore compares as equal. The lookup table is derived from
/// `PhysicalExpr`s, and those are already compared.
impl PartialEq for LiteralLookupTable {
    // Always reports equality; `_other` is never inspected.
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

// Marker impl that accompanies the always-true `PartialEq` impl above.
impl Eq for LiteralLookupTable {}
114
/// The body of a CASE expression, which consists of an optional base expression, the
/// "when/then" branches and an optional "else" branch.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
    /// Optional base expression that can be compared to literal values in the "when" expressions
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// One or more when/then expressions, in the order they should be tested
    when_then_expr: Vec<WhenThen>,
    /// Optional "else" expression
    else_expr: Option<Arc<dyn PhysicalExpr>>,
}
126
127impl CaseBody {
128    /// Derives a [ProjectedCaseBody] from this [CaseBody].
129    fn project(&self) -> Result<ProjectedCaseBody> {
130        // Determine the set of columns that are used in all the expressions of the case body.
131        let mut used_column_indices = IndexSet::<usize>::new();
132        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
133            expr.apply(|expr| {
134                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
135                    used_column_indices.insert(column.index());
136                }
137                Ok(TreeNodeRecursion::Continue)
138            })
139            .expect("Closure cannot fail");
140        };
141
142        if let Some(e) = &self.expr {
143            collect_column_indices(e);
144        }
145        self.when_then_expr.iter().for_each(|(w, t)| {
146            collect_column_indices(w);
147            collect_column_indices(t);
148        });
149        if let Some(e) = &self.else_expr {
150            collect_column_indices(e);
151        }
152
153        // Construct a mapping from the original column index to the projected column index.
154        let column_index_map = used_column_indices
155            .iter()
156            .enumerate()
157            .map(|(projected, original)| (*original, projected))
158            .collect::<IndexMap<usize, usize>>();
159
160        // Construct the projected body by rewriting each expression from the original body
161        // using the column index mapping.
162        let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
163            Arc::clone(expr)
164                .transform_down(|e| {
165                    if let Some(column) = e.as_any().downcast_ref::<Column>() {
166                        let original = column.index();
167                        let projected = *column_index_map.get(&original).unwrap();
168                        if projected != original {
169                            return Ok(Transformed::yes(Arc::new(Column::new(
170                                column.name(),
171                                projected,
172                            ))));
173                        }
174                    }
175                    Ok(Transformed::no(e))
176                })
177                .map(|t| t.data)
178        };
179
180        let projected_body = CaseBody {
181            expr: self.expr.as_ref().map(project).transpose()?,
182            when_then_expr: self
183                .when_then_expr
184                .iter()
185                .map(|(e, t)| Ok((project(e)?, project(t)?)))
186                .collect::<Result<Vec<_>>>()?,
187            else_expr: self.else_expr.as_ref().map(project).transpose()?,
188        };
189
190        // Construct the projection vector
191        let projection = column_index_map
192            .iter()
193            .sorted_by_key(|(_, v)| **v)
194            .map(|(k, _)| *k)
195            .collect::<Vec<_>>();
196
197        Ok(ProjectedCaseBody {
198            projection,
199            body: projected_body,
200        })
201    }
202}
203
/// A derived case body that can be used to evaluate a case expression after projecting
/// record batches using a projection vector.
///
/// This is used to avoid filtering columns that are not used in the
/// input `RecordBatch` when progressively evaluating a `CASE` expression's
/// remainder batches. Filtering these columns is wasteful since for a record
/// batch of `n` rows, filtering requires at worst a copy of `n - 1` values
/// per array. If these filtered values will never be accessed, the time spent
/// producing them is better avoided.
///
/// For example, if we are evaluating the following case expression that
/// only references columns B and D:
///
/// ```sql
/// SELECT CASE WHEN B > 10 THEN D ELSE NULL END FROM (VALUES (...)) T(A, B, C, D)
/// ```
///
/// Of the 4 input columns `[A, B, C, D]`, the `CASE` expression only accesses `B` and `D`.
/// Filtering `A` and `C` would be unnecessary and wasteful.
///
/// If we only retain columns `B` and `D` using `RecordBatch::project` and the projection vector
/// `[1, 3]`, the indices of these two columns will change to `[0, 1]`. To evaluate the
/// case expression, it will need to be rewritten from `CASE WHEN B@1 > 10 THEN D@3 ELSE NULL END`
/// to `CASE WHEN B@0 > 10 THEN D@1 ELSE NULL END`.
///
/// The projection vector and the rewritten expression (which only differs from the original in
/// column reference indices) are held in a `ProjectedCaseBody`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
    /// Original column indices to retain, ordered by their index in the projected batch.
    projection: Vec<usize>,
    /// The case body with its column references rewritten to the projected indices.
    body: CaseBody,
}
236
/// The CASE expression is similar to a series of nested if/else, and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
/// CASE WHEN condition THEN result
///      [WHEN ...]
///      [ELSE result]
/// END
///
/// The second form uses a base expression and then a series of "when" clauses that match on a
/// literal value.
///
/// CASE expression
///     WHEN value THEN result
///     [WHEN ...]
///     [ELSE result]
/// END
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// The case expression body
    body: CaseBody,
    /// Evaluation method to use, chosen from `body` when the expression is created
    eval_method: EvalMethod,
}
261
262impl std::fmt::Display for CaseExpr {
263    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
264        write!(f, "CASE ")?;
265        if let Some(e) = &self.body.expr {
266            write!(f, "{e} ")?;
267        }
268        for (w, t) in &self.body.when_then_expr {
269            write!(f, "WHEN {w} THEN {t} ")?;
270        }
271        if let Some(e) = &self.body.else_expr {
272            write!(f, "ELSE {e} ")?;
273        }
274        write!(f, "END")
275    }
276}
277
278/// This is a specialization for a specific use case where we can take a fast path
279/// for expressions that are infallible and can be cheaply computed for the entire
280/// record batch rather than just for the rows where the predicate is true. For now,
281/// this is limited to use with Column expressions but could potentially be used for other
282/// expressions in the future
283fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
284    expr.as_any().is::<Column>()
285}
286
287/// Creates a [FilterPredicate] from a boolean array.
288fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
289    let mut filter_builder = FilterBuilder::new(predicate);
290    if optimize {
291        // Always optimize the filter since we use them multiple times.
292        filter_builder = filter_builder.optimize();
293    }
294    filter_builder.build()
295}
296
297fn multiple_arrays(data_type: &DataType) -> bool {
298    match data_type {
299        DataType::Struct(fields) => {
300            fields.len() > 1
301                || fields.len() == 1 && multiple_arrays(fields[0].data_type())
302        }
303        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
304        _ => false,
305    }
306}
307
308// This should be removed when https://github.com/apache/arrow-rs/pull/8693
309// is merged and becomes available.
310fn filter_record_batch(
311    record_batch: &RecordBatch,
312    filter: &FilterPredicate,
313) -> std::result::Result<RecordBatch, ArrowError> {
314    let filtered_columns = record_batch
315        .columns()
316        .iter()
317        .map(|a| filter_array(a, filter))
318        .collect::<std::result::Result<Vec<_>, _>>()?;
319    // SAFETY: since we start from a valid RecordBatch, there's no need to revalidate the schema
320    // since the set of columns has not changed.
321    // The input column arrays all had the same length (since they're coming from a valid RecordBatch)
322    // and the filtering them with the same filter will produces a new set of arrays with identical
323    // lengths.
324    unsafe {
325        Ok(RecordBatch::new_unchecked(
326            record_batch.schema(),
327            filtered_columns,
328            filter.count(),
329        ))
330    }
331}
332
// This function exists purely to be able to use the same call style
// for `filter_record_batch` and `filter_array` at the point of use.
// When https://github.com/apache/arrow-rs/pull/8693 is available, replace
// both with method calls on `FilterPredicate`.
//
// Applies `filter` to `array` by delegating to `FilterPredicate::filter` and returns
// its result unchanged.
#[inline(always)]
fn filter_array(
    array: &dyn Array,
    filter: &FilterPredicate,
) -> std::result::Result<ArrayRef, ArrowError> {
    filter.filter(array)
}
344
/// An index into the partial results array that's more compact than `usize`.
///
/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
/// `Option` to keep the array of indices as compact as possible.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    /// The raw index value, or `u32::MAX` for the 'none' placeholder.
    index: u32,
}
353
/// Raw value reserved by [`PartialResultIndex`] to represent the 'none' placeholder.
const NONE_VALUE: u32 = u32::MAX;
355
356impl PartialResultIndex {
357    /// Returns the 'none' placeholder value.
358    fn none() -> Self {
359        Self { index: NONE_VALUE }
360    }
361
362    fn zero() -> Self {
363        Self { index: 0 }
364    }
365
366    /// Creates a new partial result index.
367    ///
368    /// If the provided value is greater than or equal to `u32::MAX`
369    /// an error will be returned.
370    fn try_new(index: usize) -> Result<Self> {
371        let Ok(index) = u32::try_from(index) else {
372            return internal_err!("Partial result index exceeds limit");
373        };
374
375        assert_or_internal_err!(
376            index != NONE_VALUE,
377            "Partial result index exceeds limit"
378        );
379
380        Ok(Self { index })
381    }
382
383    /// Determines if this index is the 'none' placeholder value or not.
384    fn is_none(&self) -> bool {
385        self.index == NONE_VALUE
386    }
387}
388
389impl MergeIndex for PartialResultIndex {
390    /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise.
391    fn index(&self) -> Option<usize> {
392        if self.is_none() {
393            None
394        } else {
395            Some(self.index as usize)
396        }
397    }
398}
399
400impl Debug for PartialResultIndex {
401    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
402        if self.is_none() {
403            write!(f, "null")
404        } else {
405            write!(f, "{}", self.index)
406        }
407    }
408}
409
/// The accumulation state of a [`ResultBuilder`].
enum ResultState {
    /// No results have been added yet. If the builder is finished in this state, the
    /// final result contains only null values.
    Empty,
    /// The final result needs to be computed by merging the data in `arrays`.
    Partial {
        // A `Vec` of partial results that should be merged.
        // `indices` contains indexes into this vec.
        arrays: Vec<ArrayRef>,
        // Indicates per result row from which array in `arrays` a value should be taken.
        indices: Vec<PartialResultIndex>,
    },
    /// A single branch matched all input rows. When creating the final result, no further merging
    /// of partial results is necessary.
    Complete(ColumnarValue),
}
425
/// A builder for constructing result arrays for CASE expressions.
///
/// Rather than building a monolithic array containing all results, it maintains a set of
/// partial result arrays and a mapping that indicates for each row which partial array
/// contains the result value for that row.
///
/// On finish(), the builder will merge all partial results into a single array if necessary.
/// If all rows evaluated to the same array, that array can be returned directly without
/// any merging overhead.
struct ResultBuilder {
    /// The data type used for the all-null result when no branch result was added.
    data_type: DataType,
    /// The number of rows in the final result.
    row_count: usize,
    /// What has been accumulated so far; see [`ResultState`].
    state: ResultState,
}
441
442impl ResultBuilder {
443    /// Creates a new ResultBuilder that will produce arrays of the given data type.
444    ///
445    /// The `row_count` parameter indicates the number of rows in the final result.
446    fn new(data_type: &DataType, row_count: usize) -> Self {
447        Self {
448            data_type: data_type.clone(),
449            row_count,
450            state: ResultState::Empty,
451        }
452    }
453
454    /// Adds a result for one branch of the case expression.
455    ///
456    /// `row_indices` should be a [UInt32Array] containing [RecordBatch] relative row indices
457    /// for which `value` contains result values.
458    ///
459    /// If `value` is a scalar, the scalar value will be used as the value for each row in `row_indices`.
460    ///
461    /// If `value` is an array, the values from the array and the indices from `row_indices` will be
462    /// processed pairwise. The lengths of `value` and `row_indices` must match.
463    ///
464    /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the
465    /// record batch. The then expression produced the value array `[A, D]`.
466    /// After adding this result, the result array will have been added to `partial arrays` and
467    /// `partial indices` will have been updated at indexes `1` and `4`.
468    ///
469    /// ```text
470    ///  ┌─────────┐     ┌─────────┐┌───────────┐                            ┌─────────┐┌───────────┐
471    ///  │    C    │     │ 0: None ││┌ 0 ──────┐│                            │ 0: None ││┌ 0 ──────┐│
472    ///  ├─────────┤     ├─────────┤││    A    ││                            ├─────────┤││    A    ││
473    ///  │    D    │     │ 1: None ││└─────────┘│                            │ 1:  2   ││└─────────┘│
474    ///  └─────────┘     ├─────────┤│┌ 1 ──────┐│   add_branch_result(       ├─────────┤│┌ 1 ──────┐│
475    ///   matching       │ 2:  0   │││    B    ││     row indices,           │ 2:  0   │││    B    ││
476    /// 'then' values    ├─────────┤│└─────────┘│     value                  ├─────────┤│└─────────┘│
477    ///                  │ 3: None ││           │   )                        │ 3: None ││┌ 2 ──────┐│
478    ///  ┌─────────┐     ├─────────┤│           │ ─────────────────────────▶ ├─────────┤││    C    ││
479    ///  │    1    │     │ 4: None ││           │                            │ 4:  2   ││├─────────┤│
480    ///  ├─────────┤     ├─────────┤│           │                            ├─────────┤││    D    ││
481    ///  │    4    │     │ 5:  1   ││           │                            │ 5:  1   ││└─────────┘│
482    ///  └─────────┘     └─────────┘└───────────┘                            └─────────┘└───────────┘
483    /// row indices        partial     partial                                 partial     partial
484    ///                    indices     arrays                                  indices     arrays
485    /// ```
486    fn add_branch_result(
487        &mut self,
488        row_indices: &ArrayRef,
489        value: ColumnarValue,
490    ) -> Result<()> {
491        match value {
492            ColumnarValue::Array(a) => {
493                if a.len() != row_indices.len() {
494                    internal_err!("Array length must match row indices length")
495                } else if row_indices.len() == self.row_count {
496                    self.set_complete_result(ColumnarValue::Array(a))
497                } else {
498                    self.add_partial_result(row_indices, a)
499                }
500            }
501            ColumnarValue::Scalar(s) => {
502                if row_indices.len() == self.row_count {
503                    self.set_complete_result(ColumnarValue::Scalar(s))
504                } else {
505                    self.add_partial_result(
506                        row_indices,
507                        s.to_array_of_size(row_indices.len())?,
508                    )
509                }
510            }
511        }
512    }
513
514    /// Adds a partial result array.
515    ///
516    /// This method adds the given array data as a partial result and updates the index mapping
517    /// to indicate that the specified rows should take their values from this array.
518    /// The partial results will be merged into a single array when finish() is called.
519    fn add_partial_result(
520        &mut self,
521        row_indices: &ArrayRef,
522        row_values: ArrayRef,
523    ) -> Result<()> {
524        assert_or_internal_err!(
525            row_indices.null_count() == 0,
526            "Row indices must not contain nulls"
527        );
528
529        match &mut self.state {
530            ResultState::Empty => {
531                let array_index = PartialResultIndex::zero();
532                let mut indices = vec![PartialResultIndex::none(); self.row_count];
533                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
534                    indices[*row_ix as usize] = array_index;
535                }
536
537                self.state = ResultState::Partial {
538                    arrays: vec![row_values],
539                    indices,
540                };
541
542                Ok(())
543            }
544            ResultState::Partial { arrays, indices } => {
545                let array_index = PartialResultIndex::try_new(arrays.len())?;
546
547                arrays.push(row_values);
548
549                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
550                    // This is check is only active for debug config because the callers of this method,
551                    // `case_when_with_expr` and `case_when_no_expr`, already ensure that
552                    // they only calculate a value for each row at most once.
553                    #[cfg(debug_assertions)]
554                    assert_or_internal_err!(
555                        indices[*row_ix as usize].is_none(),
556                        "Duplicate value for row {}",
557                        *row_ix
558                    );
559
560                    indices[*row_ix as usize] = array_index;
561                }
562                Ok(())
563            }
564            ResultState::Complete(_) => internal_err!(
565                "Cannot add a partial result when complete result is already set"
566            ),
567        }
568    }
569
570    /// Sets a result that applies to all rows.
571    ///
572    /// This is an optimization for cases where all rows evaluate to the same result.
573    /// When a complete result is set, the builder will return it directly from finish()
574    /// without any merging overhead.
575    fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
576        match &self.state {
577            ResultState::Empty => {
578                self.state = ResultState::Complete(value);
579                Ok(())
580            }
581            ResultState::Partial { .. } => {
582                internal_err!(
583                    "Cannot set a complete result when there are already partial results"
584                )
585            }
586            ResultState::Complete(_) => internal_err!("Complete result already set"),
587        }
588    }
589
590    /// Finishes building the result and returns the final array.
591    fn finish(self) -> Result<ColumnarValue> {
592        match self.state {
593            ResultState::Empty => {
594                // No complete result and no partial results.
595                // This can happen for case expressions with no else branch where no rows
596                // matched.
597                Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
598                    &self.data_type,
599                )?))
600            }
601            ResultState::Partial { arrays, indices } => {
602                // Merge partial results into a single array.
603                let array_refs = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
604                Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?))
605            }
606            ResultState::Complete(v) => {
607                // If we have a complete result, we can just return it.
608                Ok(v)
609            }
610        }
611    }
612}
613
614impl CaseExpr {
615    /// Create a new CASE WHEN expression
616    pub fn try_new(
617        expr: Option<Arc<dyn PhysicalExpr>>,
618        when_then_expr: Vec<WhenThen>,
619        else_expr: Option<Arc<dyn PhysicalExpr>>,
620    ) -> Result<Self> {
621        // normalize null literals to None in the else_expr (this already happens
622        // during SQL planning, but not necessarily for other use cases)
623        let else_expr = match &else_expr {
624            Some(e) => match e.as_any().downcast_ref::<Literal>() {
625                Some(lit) if lit.value().is_null() => None,
626                _ => else_expr,
627            },
628            _ => else_expr,
629        };
630
631        if when_then_expr.is_empty() {
632            return exec_err!("There must be at least one WHEN clause");
633        }
634
635        let body = CaseBody {
636            expr,
637            when_then_expr,
638            else_expr,
639        };
640
641        let eval_method = Self::find_best_eval_method(&body)?;
642
643        Ok(Self { body, eval_method })
644    }
645
646    fn find_best_eval_method(body: &CaseBody) -> Result<EvalMethod> {
647        if body.expr.is_some() {
648            if let Some(mapping) = LiteralLookupTable::maybe_new(body) {
649                return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
650            }
651
652            return Ok(EvalMethod::WithExpression(body.project()?));
653        }
654
655        Ok(
656            if body.when_then_expr.len() == 1
657                && is_cheap_and_infallible(&(body.when_then_expr[0].1))
658                && body.else_expr.is_none()
659            {
660                EvalMethod::InfallibleExprOrNull
661            } else if body.when_then_expr.len() == 1
662                && body.when_then_expr[0].1.as_any().is::<Literal>()
663                && body.else_expr.is_some()
664                && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
665            {
666                EvalMethod::ScalarOrScalar
667            } else if body.when_then_expr.len() == 1 {
668                EvalMethod::ExpressionOrExpression(body.project()?)
669            } else {
670                EvalMethod::NoExpression(body.project()?)
671            },
672        )
673    }
674
675    /// Optional base expression that can be compared to literal values in the "when" expressions
676    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
677        self.body.expr.as_ref()
678    }
679
680    /// One or more when/then expressions
681    pub fn when_then_expr(&self) -> &[WhenThen] {
682        &self.body.when_then_expr
683    }
684
685    /// Optional "else" expression
686    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
687        self.body.else_expr.as_ref()
688    }
689}
690
691impl CaseBody {
692    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
693        // since all then results have the same data type, we can choose any one as the
694        // return data type except for the null.
695        let mut data_type = DataType::Null;
696        for i in 0..self.when_then_expr.len() {
697            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
698            if !data_type.equals_datatype(&DataType::Null) {
699                break;
700            }
701        }
702        // if all then results are null, we use data type of else expr instead if possible.
703        if data_type.equals_datatype(&DataType::Null)
704            && let Some(e) = &self.else_expr
705        {
706            data_type = e.data_type(input_schema)?;
707        }
708
709        Ok(data_type)
710    }
711
712    /// See [CaseExpr::case_when_with_expr].
713    fn case_when_with_expr(
714        &self,
715        batch: &RecordBatch,
716        return_type: &DataType,
717    ) -> Result<ColumnarValue> {
718        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
719
720        // `remainder_rows` contains the indices of the rows that need to be evaluated
721        let mut remainder_rows: ArrayRef =
722            Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
723        // `remainder_batch` contains the rows themselves that need to be evaluated
724        let mut remainder_batch = Cow::Borrowed(batch);
725
726        // evaluate the base expression
727        let mut base_values = self
728            .expr
729            .as_ref()
730            .unwrap()
731            .evaluate(batch)?
732            .into_array(batch.num_rows())?;
733
734        // Fill in a result value already for rows where the base expression value is null
735        // Since each when expression is tested against the base expression using the equality
736        // operator, null base values can never match any when expression. `x = NULL` is falsy,
737        // for all possible values of `x`.
738        let base_null_count = base_values.logical_null_count();
739        if base_null_count > 0 {
740            // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
741            // We already checked there are nulls, so we can be sure a new buffer will not be
742            // created.
743            let base_not_nulls = is_not_null(base_values.as_ref())?;
744            let base_all_null = base_null_count == remainder_batch.num_rows();
745
746            // If there is an else expression, use that as the default value for the null rows
747            // Otherwise the default `null` value from the result builder will be used.
748            if let Some(e) = &self.else_expr {
749                let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
750
751                if base_all_null {
752                    // All base values were null, so no need to filter
753                    let nulls_value = expr.evaluate(&remainder_batch)?;
754                    result_builder.add_branch_result(&remainder_rows, nulls_value)?;
755                } else {
756                    // Filter out the null rows and evaluate the else expression for those
757                    let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
758                    let nulls_batch =
759                        filter_record_batch(&remainder_batch, &nulls_filter)?;
760                    let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
761                    let nulls_value = expr.evaluate(&nulls_batch)?;
762                    result_builder.add_branch_result(&nulls_rows, nulls_value)?;
763                }
764            }
765
766            // All base values are null, so we can return early
767            if base_all_null {
768                return result_builder.finish();
769            }
770
771            // Remove the null rows from the remainder batch
772            let not_null_filter = create_filter(&base_not_nulls, true);
773            remainder_batch =
774                Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
775            remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
776            base_values = filter_array(&base_values, &not_null_filter)?;
777        }
778
779        // The types of case and when expressions will be coerced to match.
780        // We only need to check if the base_value is nested.
781        let base_value_is_nested = base_values.data_type().is_nested();
782
783        for i in 0..self.when_then_expr.len() {
784            // Evaluate the 'when' predicate for the remainder batch
785            // This results in a boolean array with the same length as the remaining number of rows
786            let when_expr = &self.when_then_expr[i].0;
787            let when_value = match when_expr.evaluate(&remainder_batch)? {
788                ColumnarValue::Array(a) => {
789                    compare_with_eq(&a, &base_values, base_value_is_nested)
790                }
791                ColumnarValue::Scalar(s) => {
792                    compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
793                }
794            }?;
795
796            // `true_count` ignores `true` values where the validity bit is not set, so there's
797            // no need to call `prep_null_mask_filter`.
798            let when_true_count = when_value.true_count();
799
800            // If the 'when' predicate did not match any rows, continue to the next branch immediately
801            if when_true_count == 0 {
802                continue;
803            }
804
805            // If the 'when' predicate matched all remaining rows, there is no need to filter
806            if when_true_count == remainder_batch.num_rows() {
807                let then_expression = &self.when_then_expr[i].1;
808                let then_value = then_expression.evaluate(&remainder_batch)?;
809                result_builder.add_branch_result(&remainder_rows, then_value)?;
810                return result_builder.finish();
811            }
812
813            // Filter the remainder batch based on the 'when' value
814            // This results in a batch containing only the rows that need to be evaluated
815            // for the current branch
816            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
817            // this unconditionally.
818            let then_filter = create_filter(&when_value, true);
819            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
820            let then_rows = filter_array(&remainder_rows, &then_filter)?;
821
822            let then_expression = &self.when_then_expr[i].1;
823            let then_value = then_expression.evaluate(&then_batch)?;
824            result_builder.add_branch_result(&then_rows, then_value)?;
825
826            // If this is the last 'when' branch and there is no 'else' expression, there's no
827            // point in calculating the remaining rows.
828            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
829                return result_builder.finish();
830            }
831
832            // Prepare the next when branch (or the else branch)
833            let next_selection = match when_value.null_count() {
834                0 => not(&when_value),
835                _ => {
836                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
837                    // as false
838                    not(&prep_null_mask_filter(&when_value))
839                }
840            }?;
841            let next_filter = create_filter(&next_selection, true);
842            remainder_batch =
843                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
844            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
845            base_values = filter_array(&base_values, &next_filter)?;
846        }
847
848        // If we reached this point, some rows were left unmatched.
849        // Check if those need to be evaluated using the 'else' expression.
850        if let Some(e) = &self.else_expr {
851            // keep `else_expr`'s data type and return type consistent
852            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
853            let else_value = expr.evaluate(&remainder_batch)?;
854            result_builder.add_branch_result(&remainder_rows, else_value)?;
855        }
856
857        result_builder.finish()
858    }
859
860    /// See [CaseExpr::case_when_no_expr].
861    fn case_when_no_expr(
862        &self,
863        batch: &RecordBatch,
864        return_type: &DataType,
865    ) -> Result<ColumnarValue> {
866        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
867
868        // `remainder_rows` contains the indices of the rows that need to be evaluated
869        let mut remainder_rows: ArrayRef =
870            Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
871        // `remainder_batch` contains the rows themselves that need to be evaluated
872        let mut remainder_batch = Cow::Borrowed(batch);
873
874        for i in 0..self.when_then_expr.len() {
875            // Evaluate the 'when' predicate for the remainder batch
876            // This results in a boolean array with the same length as the remaining number of rows
877            let when_predicate = &self.when_then_expr[i].0;
878            let when_value = when_predicate
879                .evaluate(&remainder_batch)?
880                .into_array(remainder_batch.num_rows())?;
881            let when_value = as_boolean_array(&when_value).map_err(|_| {
882                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
883            })?;
884
885            // `true_count` ignores `true` values where the validity bit is not set, so there's
886            // no need to call `prep_null_mask_filter`.
887            let when_true_count = when_value.true_count();
888
889            // If the 'when' predicate did not match any rows, continue to the next branch immediately
890            if when_true_count == 0 {
891                continue;
892            }
893
894            // If the 'when' predicate matched all remaining rows, there is no need to filter
895            if when_true_count == remainder_batch.num_rows() {
896                let then_expression = &self.when_then_expr[i].1;
897                let then_value = then_expression.evaluate(&remainder_batch)?;
898                result_builder.add_branch_result(&remainder_rows, then_value)?;
899                return result_builder.finish();
900            }
901
902            // Filter the remainder batch based on the 'when' value
903            // This results in a batch containing only the rows that need to be evaluated
904            // for the current branch
905            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
906            // this unconditionally.
907            let then_filter = create_filter(when_value, true);
908            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
909            let then_rows = filter_array(&remainder_rows, &then_filter)?;
910
911            let then_expression = &self.when_then_expr[i].1;
912            let then_value = then_expression.evaluate(&then_batch)?;
913            result_builder.add_branch_result(&then_rows, then_value)?;
914
915            // If this is the last 'when' branch and there is no 'else' expression, there's no
916            // point in calculating the remaining rows.
917            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
918                return result_builder.finish();
919            }
920
921            // Prepare the next when branch (or the else branch)
922            let next_selection = match when_value.null_count() {
923                0 => not(when_value),
924                _ => {
925                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
926                    // as false
927                    not(&prep_null_mask_filter(when_value))
928                }
929            }?;
930            let next_filter = create_filter(&next_selection, true);
931            remainder_batch =
932                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
933            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
934        }
935
936        // If we reached this point, some rows were left unmatched.
937        // Check if those need to be evaluated using the 'else' expression.
938        if let Some(e) = &self.else_expr {
939            // keep `else_expr`'s data type and return type consistent
940            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
941            let else_value = expr.evaluate(&remainder_batch)?;
942            result_builder.add_branch_result(&remainder_rows, else_value)?;
943        }
944
945        result_builder.finish()
946    }
947
948    /// See [CaseExpr::expr_or_expr].
949    fn expr_or_expr(
950        &self,
951        batch: &RecordBatch,
952        when_value: &BooleanArray,
953    ) -> Result<ColumnarValue> {
954        let when_value = match when_value.null_count() {
955            0 => Cow::Borrowed(when_value),
956            _ => {
957                // `prep_null_mask_filter` is required to ensure null is treated as false
958                Cow::Owned(prep_null_mask_filter(when_value))
959            }
960        };
961
962        let optimize_filter = batch.num_columns() > 1
963            || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
964
965        let when_filter = create_filter(&when_value, optimize_filter);
966        let then_batch = filter_record_batch(batch, &when_filter)?;
967        let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
968
969        match &self.else_expr {
970            None => {
971                let then_array = then_value.to_array(when_value.true_count())?;
972                scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array)
973            }
974            Some(else_expr) => {
975                let else_selection = not(&when_value)?;
976                let else_filter = create_filter(&else_selection, optimize_filter);
977                let else_batch = filter_record_batch(batch, &else_filter)?;
978
979                // keep `else_expr`'s data type and return type consistent
980                let return_type = self.data_type(&batch.schema())?;
981                let else_expr =
982                    try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone())
983                        .unwrap_or_else(|_| Arc::clone(else_expr));
984
985                let else_value = else_expr.evaluate(&else_batch)?;
986
987                Ok(ColumnarValue::Array(match (then_value, else_value) {
988                    (ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
989                        merge(&when_value, &t, &e)
990                    }
991                    (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
992                        merge(&when_value, &t.to_scalar()?, &e)
993                    }
994                    (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
995                        merge(&when_value, &t, &e.to_scalar()?)
996                    }
997                    (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
998                        merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
999                    }
1000                }?))
1001            }
1002        }
1003    }
1004}
1005
1006impl CaseExpr {
1007    /// This function evaluates the form of CASE that matches an expression to fixed values.
1008    ///
1009    /// CASE expression
1010    ///     WHEN value THEN result
1011    ///     [WHEN ...]
1012    ///     [ELSE result]
1013    /// END
1014    fn case_when_with_expr(
1015        &self,
1016        batch: &RecordBatch,
1017        projected: &ProjectedCaseBody,
1018    ) -> Result<ColumnarValue> {
1019        let return_type = self.data_type(&batch.schema())?;
1020        if projected.projection.len() < batch.num_columns() {
1021            let projected_batch = batch.project(&projected.projection)?;
1022            projected
1023                .body
1024                .case_when_with_expr(&projected_batch, &return_type)
1025        } else {
1026            self.body.case_when_with_expr(batch, &return_type)
1027        }
1028    }
1029
1030    /// This function evaluates the form of CASE where each WHEN expression is a boolean
1031    /// expression.
1032    ///
1033    /// CASE WHEN condition THEN result
1034    ///      [WHEN ...]
1035    ///      [ELSE result]
1036    /// END
1037    fn case_when_no_expr(
1038        &self,
1039        batch: &RecordBatch,
1040        projected: &ProjectedCaseBody,
1041    ) -> Result<ColumnarValue> {
1042        let return_type = self.data_type(&batch.schema())?;
1043        if projected.projection.len() < batch.num_columns() {
1044            let projected_batch = batch.project(&projected.projection)?;
1045            projected
1046                .body
1047                .case_when_no_expr(&projected_batch, &return_type)
1048        } else {
1049            self.body.case_when_no_expr(batch, &return_type)
1050        }
1051    }
1052
1053    /// This function evaluates the specialized case of:
1054    ///
1055    /// CASE WHEN condition THEN column
1056    ///      [ELSE NULL]
1057    /// END
1058    ///
1059    /// Note that this function is only safe to use for "then" expressions
1060    /// that are infallible because the expression will be evaluated for all
1061    /// rows in the input batch.
1062    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1063        let when_expr = &self.body.when_then_expr[0].0;
1064        let then_expr = &self.body.when_then_expr[0].1;
1065
1066        match when_expr.evaluate(batch)? {
1067            // WHEN true --> column
1068            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1069                then_expr.evaluate(batch)
1070            }
1071            // WHEN [false | null] --> NULL
1072            ColumnarValue::Scalar(_) => {
1073                // return scalar NULL value
1074                ScalarValue::try_from(self.data_type(&batch.schema())?)
1075                    .map(ColumnarValue::Scalar)
1076            }
1077            // WHEN column --> column
1078            ColumnarValue::Array(bit_mask) => {
1079                let bit_mask = bit_mask
1080                    .as_any()
1081                    .downcast_ref::<BooleanArray>()
1082                    .expect("predicate should evaluate to a boolean array");
1083                // invert the bitmask
1084                let bit_mask = match bit_mask.null_count() {
1085                    0 => not(bit_mask)?,
1086                    _ => not(&prep_null_mask_filter(bit_mask))?,
1087                };
1088                match then_expr.evaluate(batch)? {
1089                    ColumnarValue::Array(array) => {
1090                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1091                    }
1092                    ColumnarValue::Scalar(_) => {
1093                        internal_err!("expression did not evaluate to an array")
1094                    }
1095                }
1096            }
1097        }
1098    }
1099
1100    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1101        let return_type = self.data_type(&batch.schema())?;
1102
1103        // evaluate when expression
1104        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1105        let when_value = when_value.into_array(batch.num_rows())?;
1106        let when_value = as_boolean_array(&when_value).map_err(|_| {
1107            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1108        })?;
1109
1110        // Treat 'NULL' as false value
1111        let when_value = match when_value.null_count() {
1112            0 => Cow::Borrowed(when_value),
1113            _ => Cow::Owned(prep_null_mask_filter(when_value)),
1114        };
1115
1116        // evaluate then_value
1117        let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1118        let then_value = Scalar::new(then_value.into_array(1)?);
1119
1120        let Some(e) = &self.body.else_expr else {
1121            return internal_err!("expression did not evaluate to an array");
1122        };
1123        // keep `else_expr`'s data type and return type consistent
1124        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1125        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1126        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1127    }
1128
1129    fn expr_or_expr(
1130        &self,
1131        batch: &RecordBatch,
1132        projected: &ProjectedCaseBody,
1133    ) -> Result<ColumnarValue> {
1134        // evaluate when condition on batch
1135        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1136        // `num_rows == 1` is intentional to avoid expanding scalars.
1137        // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1138        // below will avoid incorrectly using the scalar as a merge/zip mask.
1139        let when_value = when_value.into_array(1)?;
1140        let when_value = as_boolean_array(&when_value).map_err(|e| {
1141            DataFusionError::Context(
1142                "WHEN expression did not return a BooleanArray".to_string(),
1143                Box::new(e),
1144            )
1145        })?;
1146
1147        let true_count = when_value.true_count();
1148        if true_count == when_value.len() {
1149            // All input rows are true, just call the 'then' expression
1150            self.body.when_then_expr[0].1.evaluate(batch)
1151        } else if true_count == 0 {
1152            // All input rows are false/null, just call the 'else' expression
1153            match &self.body.else_expr {
1154                Some(else_expr) => else_expr.evaluate(batch),
1155                None => {
1156                    let return_type = self.data_type(&batch.schema())?;
1157                    Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
1158                        &return_type,
1159                    )?))
1160                }
1161            }
1162        } else if projected.projection.len() < batch.num_columns() {
1163            // The case expressions do not use all the columns of the input batch.
1164            // Project first to reduce time spent filtering.
1165            let projected_batch = batch.project(&projected.projection)?;
1166            projected.body.expr_or_expr(&projected_batch, when_value)
1167        } else {
1168            // All columns are used in the case expressions, so there is no need to project.
1169            self.body.expr_or_expr(batch, when_value)
1170        }
1171    }
1172
1173    fn with_lookup_table(
1174        &self,
1175        batch: &RecordBatch,
1176        lookup_table: &LiteralLookupTable,
1177    ) -> Result<ColumnarValue> {
1178        let expr = self.body.expr.as_ref().unwrap();
1179        let evaluated_expression = expr.evaluate(batch)?;
1180
1181        let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_));
1182        let evaluated_expression = evaluated_expression.to_array(1)?;
1183
1184        let values = lookup_table.map_keys_to_values(&evaluated_expression)?;
1185
1186        let result = if is_scalar {
1187            ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?)
1188        } else {
1189            ColumnarValue::Array(values)
1190        };
1191
1192        Ok(result)
1193    }
1194}
1195
1196impl PhysicalExpr for CaseExpr {
    /// Return a reference to Any that can be used for down-casting
    /// this expression back to a concrete `CaseExpr`.
    fn as_any(&self) -> &dyn Any {
        self
    }
1201
    /// Returns the data type this CASE expression produces for `input_schema`.
    ///
    /// The computation is delegated to the case body.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.body.data_type(input_schema)
    }
1205
1206    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1207        let nullable_then = self
1208            .body
1209            .when_then_expr
1210            .iter()
1211            .filter_map(|(w, t)| {
1212                let is_nullable = match t.nullable(input_schema) {
1213                    // Pass on error determining nullability verbatim
1214                    Err(e) => return Some(Err(e)),
1215                    Ok(n) => n,
1216                };
1217
1218                // Branches with a then expression that is not nullable do not impact the
1219                // nullability of the case expression.
1220                if !is_nullable {
1221                    return None;
1222                }
1223
1224                // For case-with-expression assume all 'then' expressions are reachable
1225                if self.body.expr.is_some() {
1226                    return Some(Ok(()));
1227                }
1228
1229                // For branches with a nullable 'then' expression, try to determine
1230                // if the 'then' expression is ever reachable in the situation where
1231                // it would evaluate to null.
1232
1233                // Replace the `then` expression with `NULL` in the `when` expression
1234                let with_null = match replace_with_null(w, t.as_ref(), input_schema) {
1235                    Err(e) => return Some(Err(e)),
1236                    Ok(e) => e,
1237                };
1238
1239                // Try to const evaluate the modified `when` expression.
1240                let predicate_result = match evaluate_predicate(&with_null) {
1241                    Err(e) => return Some(Err(e)),
1242                    Ok(b) => b,
1243                };
1244
1245                match predicate_result {
1246                    // Evaluation was inconclusive or true, so the 'then' expression is reachable
1247                    None | Some(true) => Some(Ok(())),
1248                    // Evaluation proves the branch will never be taken.
1249                    // The most common pattern for this is `WHEN x IS NOT NULL THEN x`.
1250                    Some(false) => None,
1251                }
1252            })
1253            .next();
1254
1255        if let Some(nullable_then) = nullable_then {
1256            // There is at least one reachable nullable 'then' expression, so the case
1257            // expression itself is nullable.
1258            // Use `Result::map` to propagate the error from `nullable_then` if there is one.
1259            nullable_then.map(|_| true)
1260        } else if let Some(e) = &self.body.else_expr {
1261            // There are no reachable nullable 'then' expressions, so all we still need to
1262            // check is the 'else' expression's nullability.
1263            e.nullable(input_schema)
1264        } else {
1265            // CASE produces NULL if there is no `else` expr
1266            // (aka when none of the `when_then_exprs` match)
1267            Ok(true)
1268        }
1269    }
1270
1271    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1272        match &self.eval_method {
1273            EvalMethod::WithExpression(p) => {
1274                // this use case evaluates "expr" and then compares the values with the "when"
1275                // values
1276                self.case_when_with_expr(batch, p)
1277            }
1278            EvalMethod::NoExpression(p) => {
1279                // The "when" conditions all evaluate to boolean in this use case and can be
1280                // arbitrary expressions
1281                self.case_when_no_expr(batch, p)
1282            }
1283            EvalMethod::InfallibleExprOrNull => {
1284                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
1285                self.case_column_or_null(batch)
1286            }
1287            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1288            EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1289            EvalMethod::WithExprScalarLookupTable(lookup_table) => {
1290                self.with_lookup_table(batch, lookup_table)
1291            }
1292        }
1293    }
1294
1295    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1296        let mut children = vec![];
1297        if let Some(expr) = &self.body.expr {
1298            children.push(expr)
1299        }
1300        self.body.when_then_expr.iter().for_each(|(cond, value)| {
1301            children.push(cond);
1302            children.push(value);
1303        });
1304
1305        if let Some(else_expr) = &self.body.else_expr {
1306            children.push(else_expr)
1307        }
1308        children
1309    }
1310
1311    // For physical CaseExpr, we do not allow modifying children size
1312    fn with_new_children(
1313        self: Arc<Self>,
1314        children: Vec<Arc<dyn PhysicalExpr>>,
1315    ) -> Result<Arc<dyn PhysicalExpr>> {
1316        if children.len() != self.children().len() {
1317            internal_err!("CaseExpr: Wrong number of children")
1318        } else {
1319            let (expr, when_then_expr, else_expr) =
1320                match (self.expr().is_some(), self.body.else_expr.is_some()) {
1321                    (true, true) => (
1322                        Some(&children[0]),
1323                        &children[1..children.len() - 1],
1324                        Some(&children[children.len() - 1]),
1325                    ),
1326                    (true, false) => {
1327                        (Some(&children[0]), &children[1..children.len()], None)
1328                    }
1329                    (false, true) => (
1330                        None,
1331                        &children[0..children.len() - 1],
1332                        Some(&children[children.len() - 1]),
1333                    ),
1334                    (false, false) => (None, &children[0..children.len()], None),
1335                };
1336            Ok(Arc::new(CaseExpr::try_new(
1337                expr.cloned(),
1338                when_then_expr.iter().cloned().tuples().collect(),
1339                else_expr.cloned(),
1340            )?))
1341        }
1342    }
1343
1344    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1345        write!(f, "CASE ")?;
1346        if let Some(e) = &self.body.expr {
1347            e.fmt_sql(f)?;
1348            write!(f, " ")?;
1349        }
1350
1351        for (w, t) in &self.body.when_then_expr {
1352            write!(f, "WHEN ")?;
1353            w.fmt_sql(f)?;
1354            write!(f, " THEN ")?;
1355            t.fmt_sql(f)?;
1356            write!(f, " ")?;
1357        }
1358
1359        if let Some(e) = &self.body.else_expr {
1360            write!(f, "ELSE ")?;
1361            e.fmt_sql(f)?;
1362            write!(f, " ")?;
1363        }
1364        write!(f, "END")
1365    }
1366}
1367
1368/// Attempts to const evaluate the given `predicate`.
1369/// Returns:
1370/// - `Some(true)` if the predicate evaluates to a truthy value.
1371/// - `Some(false)` if the predicate evaluates to a falsy value.
1372/// - `None` if the predicate could not be evaluated.
1373fn evaluate_predicate(predicate: &Arc<dyn PhysicalExpr>) -> Result<Option<bool>> {
1374    // Create a dummy record with no columns and one row
1375    let batch = RecordBatch::try_new_with_options(
1376        Arc::new(Schema::empty()),
1377        vec![],
1378        &RecordBatchOptions::new().with_row_count(Some(1)),
1379    )?;
1380
1381    // Evaluate the predicate and interpret the result as a boolean
1382    let result = match predicate.evaluate(&batch) {
1383        // An error during evaluation means we couldn't const evaluate the predicate, so return `None`
1384        Err(_) => None,
1385        Ok(ColumnarValue::Array(array)) => Some(
1386            ScalarValue::try_from_array(array.as_ref(), 0)?
1387                .cast_to(&DataType::Boolean)?,
1388        ),
1389        Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?),
1390    };
1391    Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true)))))
1392}
1393
1394fn replace_with_null(
1395    expr: &Arc<dyn PhysicalExpr>,
1396    expr_to_replace: &dyn PhysicalExpr,
1397    input_schema: &Schema,
1398) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
1399    let with_null = Arc::clone(expr)
1400        .transform_down(|e| {
1401            if e.as_ref().dyn_eq(expr_to_replace) {
1402                let data_type = e.data_type(input_schema)?;
1403                let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
1404                Ok(Transformed::yes(null_literal))
1405            } else {
1406                Ok(Transformed::no(e))
1407            }
1408        })?
1409        .data;
1410    Ok(with_null)
1411}
1412
1413/// Create a CASE expression
1414pub fn case(
1415    expr: Option<Arc<dyn PhysicalExpr>>,
1416    when_thens: Vec<WhenThen>,
1417    else_expr: Option<Arc<dyn PhysicalExpr>>,
1418) -> Result<Arc<dyn PhysicalExpr>> {
1419    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1420}
1421
1422#[cfg(test)]
1423mod tests {
1424    use super::*;
1425
1426    use crate::expressions;
1427    use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null, lit};
1428    use arrow::buffer::Buffer;
1429    use arrow::datatypes::DataType::Float64;
1430    use arrow::datatypes::Field;
1431    use datafusion_common::cast::{as_float64_array, as_int32_array};
1432    use datafusion_common::plan_err;
1433    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1434    use datafusion_expr::type_coercion::binary::comparison_coercion;
1435    use datafusion_expr_common::operator::Operator;
1436    use datafusion_physical_expr_common::physical_expr::fmt_sql;
1437    use half::f16;
1438
1439    #[test]
1440    fn case_with_expr() -> Result<()> {
1441        let batch = case_test_batch()?;
1442        let schema = batch.schema();
1443
1444        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1445        let when1 = lit("foo");
1446        let then1 = lit(123i32);
1447        let when2 = lit("bar");
1448        let then2 = lit(456i32);
1449
1450        let expr = generate_case_when_with_type_coercion(
1451            Some(col("a", &schema)?),
1452            vec![(when1, then1), (when2, then2)],
1453            None,
1454            schema.as_ref(),
1455        )?;
1456        let result = expr
1457            .evaluate(&batch)?
1458            .into_array(batch.num_rows())
1459            .expect("Failed to convert to array");
1460        let result = as_int32_array(&result)?;
1461
1462        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1463
1464        assert_eq!(expected, result);
1465
1466        Ok(())
1467    }
1468
1469    #[test]
1470    fn case_with_expr_dictionary() -> Result<()> {
1471        let schema = Schema::new(vec![Field::new(
1472            "a",
1473            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1474            true,
1475        )]);
1476        let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1477        let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1478        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1479        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1480
1481        let schema = batch.schema();
1482
1483        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1484        let when1 = lit("foo");
1485        let then1 = lit(123i32);
1486        let when2 = lit("bar");
1487        let then2 = lit(456i32);
1488
1489        let expr = generate_case_when_with_type_coercion(
1490            Some(col("a", &schema)?),
1491            vec![(when1, then1), (when2, then2)],
1492            None,
1493            schema.as_ref(),
1494        )?;
1495        let result = expr
1496            .evaluate(&batch)?
1497            .into_array(batch.num_rows())
1498            .expect("Failed to convert to array");
1499        let result = as_int32_array(&result)?;
1500
1501        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1502
1503        assert_eq!(expected, result);
1504
1505        Ok(())
1506    }
1507
1508    // Make sure we are not failing when got literal in case when but input is dictionary encoded
1509    #[test]
1510    fn case_with_expr_primitive_dictionary() -> Result<()> {
1511        let schema = Schema::new(vec![Field::new(
1512            "a",
1513            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64)),
1514            true,
1515        )]);
1516        let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1517        let values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]);
1518        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1519        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1520
1521        let schema = batch.schema();
1522
1523        // CASE a WHEN 10 THEN 123 WHEN 30 THEN 456 END
1524        let when1 = lit(10_u64);
1525        let then1 = lit(123_i32);
1526        let when2 = lit(30_u64);
1527        let then2 = lit(456_i32);
1528
1529        let expr = generate_case_when_with_type_coercion(
1530            Some(col("a", &schema)?),
1531            vec![(when1, then1), (when2, then2)],
1532            None,
1533            schema.as_ref(),
1534        )?;
1535        let result = expr
1536            .evaluate(&batch)?
1537            .into_array(batch.num_rows())
1538            .expect("Failed to convert to array");
1539        let result = as_int32_array(&result)?;
1540
1541        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1542
1543        assert_eq!(expected, result);
1544
1545        Ok(())
1546    }
1547
1548    // Make sure we are not failing when got literal in case when but input is dictionary encoded
1549    #[test]
1550    fn case_with_expr_boolean_dictionary() -> Result<()> {
1551        let schema = Schema::new(vec![Field::new(
1552            "a",
1553            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean)),
1554            true,
1555        )]);
1556        let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]);
1557        let values = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
1558        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1559        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1560
1561        let schema = batch.schema();
1562
1563        // CASE a WHEN true THEN 123 WHEN false THEN 456 END
1564        let when1 = lit(true);
1565        let then1 = lit(123i32);
1566        let when2 = lit(false);
1567        let then2 = lit(456i32);
1568
1569        let expr = generate_case_when_with_type_coercion(
1570            Some(col("a", &schema)?),
1571            vec![(when1, then1), (when2, then2)],
1572            None,
1573            schema.as_ref(),
1574        )?;
1575        let result = expr
1576            .evaluate(&batch)?
1577            .into_array(batch.num_rows())
1578            .expect("Failed to convert to array");
1579        let result = as_int32_array(&result)?;
1580
1581        let expected = &Int32Array::from(vec![Some(123), Some(456), None, Some(123)]);
1582
1583        assert_eq!(expected, result);
1584
1585        Ok(())
1586    }
1587
1588    #[test]
1589    fn case_with_expr_all_null_dictionary() -> Result<()> {
1590        let schema = Schema::new(vec![Field::new(
1591            "a",
1592            DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
1593            true,
1594        )]);
1595        let keys = UInt8Array::from(vec![2u8, 2u8, 2u8, 2u8]);
1596        let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1597        let dictionary = DictionaryArray::new(keys, Arc::new(values));
1598        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?;
1599
1600        let schema = batch.schema();
1601
1602        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1603        let when1 = lit("foo");
1604        let then1 = lit(123i32);
1605        let when2 = lit("bar");
1606        let then2 = lit(456i32);
1607
1608        let expr = generate_case_when_with_type_coercion(
1609            Some(col("a", &schema)?),
1610            vec![(when1, then1), (when2, then2)],
1611            None,
1612            schema.as_ref(),
1613        )?;
1614        let result = expr
1615            .evaluate(&batch)?
1616            .into_array(batch.num_rows())
1617            .expect("Failed to convert to array");
1618        let result = as_int32_array(&result)?;
1619
1620        let expected = &Int32Array::from(vec![None, None, None, None]);
1621
1622        assert_eq!(expected, result);
1623
1624        Ok(())
1625    }
1626
1627    #[test]
1628    fn case_with_expr_else() -> Result<()> {
1629        let batch = case_test_batch()?;
1630        let schema = batch.schema();
1631
1632        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
1633        let when1 = lit("foo");
1634        let then1 = lit(123i32);
1635        let when2 = lit("bar");
1636        let then2 = lit(456i32);
1637        let else_value = lit(999i32);
1638
1639        let expr = generate_case_when_with_type_coercion(
1640            Some(col("a", &schema)?),
1641            vec![(when1, then1), (when2, then2)],
1642            Some(else_value),
1643            schema.as_ref(),
1644        )?;
1645        let result = expr
1646            .evaluate(&batch)?
1647            .into_array(batch.num_rows())
1648            .expect("Failed to convert to array");
1649        let result = as_int32_array(&result)?;
1650
1651        let expected =
1652            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1653
1654        assert_eq!(expected, result);
1655
1656        Ok(())
1657    }
1658
1659    #[test]
1660    fn case_with_expr_divide_by_zero() -> Result<()> {
1661        let batch = case_test_batch1()?;
1662        let schema = batch.schema();
1663
1664        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
1665        let when1 = lit(0i32);
1666        let then1 = lit(ScalarValue::Float64(None));
1667        let else_value = binary(
1668            lit(25.0f64),
1669            Operator::Divide,
1670            cast(col("a", &schema)?, &batch.schema(), Float64)?,
1671            &batch.schema(),
1672        )?;
1673
1674        let expr = generate_case_when_with_type_coercion(
1675            Some(col("a", &schema)?),
1676            vec![(when1, then1)],
1677            Some(else_value),
1678            schema.as_ref(),
1679        )?;
1680        let result = expr
1681            .evaluate(&batch)?
1682            .into_array(batch.num_rows())
1683            .expect("Failed to convert to array");
1684        let result =
1685            as_float64_array(&result).expect("failed to downcast to Float64Array");
1686
1687        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1688
1689        assert_eq!(expected, result);
1690
1691        Ok(())
1692    }
1693
1694    #[test]
1695    fn case_without_expr() -> Result<()> {
1696        let batch = case_test_batch()?;
1697        let schema = batch.schema();
1698
1699        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
1700        let when1 = binary(
1701            col("a", &schema)?,
1702            Operator::Eq,
1703            lit("foo"),
1704            &batch.schema(),
1705        )?;
1706        let then1 = lit(123i32);
1707        let when2 = binary(
1708            col("a", &schema)?,
1709            Operator::Eq,
1710            lit("bar"),
1711            &batch.schema(),
1712        )?;
1713        let then2 = lit(456i32);
1714
1715        let expr = generate_case_when_with_type_coercion(
1716            None,
1717            vec![(when1, then1), (when2, then2)],
1718            None,
1719            schema.as_ref(),
1720        )?;
1721        let result = expr
1722            .evaluate(&batch)?
1723            .into_array(batch.num_rows())
1724            .expect("Failed to convert to array");
1725        let result = as_int32_array(&result)?;
1726
1727        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1728
1729        assert_eq!(expected, result);
1730
1731        Ok(())
1732    }
1733
1734    #[test]
1735    fn case_with_expr_when_null() -> Result<()> {
1736        let batch = case_test_batch()?;
1737        let schema = batch.schema();
1738
1739        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
1740        let when1 = lit(ScalarValue::Utf8(None));
1741        let then1 = lit(0i32);
1742        let when2 = col("a", &schema)?;
1743        let then2 = lit(123i32);
1744        let else_value = lit(999i32);
1745
1746        let expr = generate_case_when_with_type_coercion(
1747            Some(col("a", &schema)?),
1748            vec![(when1, then1), (when2, then2)],
1749            Some(else_value),
1750            schema.as_ref(),
1751        )?;
1752        let result = expr
1753            .evaluate(&batch)?
1754            .into_array(batch.num_rows())
1755            .expect("Failed to convert to array");
1756        let result = as_int32_array(&result)?;
1757
1758        let expected =
1759            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
1760
1761        assert_eq!(expected, result);
1762
1763        Ok(())
1764    }
1765
1766    #[test]
1767    fn case_without_expr_divide_by_zero() -> Result<()> {
1768        let batch = case_test_batch1()?;
1769        let schema = batch.schema();
1770
1771        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
1772        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1773        let then1 = binary(
1774            lit(25.0f64),
1775            Operator::Divide,
1776            cast(col("a", &schema)?, &batch.schema(), Float64)?,
1777            &batch.schema(),
1778        )?;
1779        let x = lit(ScalarValue::Float64(None));
1780
1781        let expr = generate_case_when_with_type_coercion(
1782            None,
1783            vec![(when1, then1)],
1784            Some(x),
1785            schema.as_ref(),
1786        )?;
1787        let result = expr
1788            .evaluate(&batch)?
1789            .into_array(batch.num_rows())
1790            .expect("Failed to convert to array");
1791        let result =
1792            as_float64_array(&result).expect("failed to downcast to Float64Array");
1793
1794        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1795
1796        assert_eq!(expected, result);
1797
1798        Ok(())
1799    }
1800
1801    fn case_test_batch1() -> Result<RecordBatch> {
1802        let schema = Schema::new(vec![
1803            Field::new("a", DataType::Int32, true),
1804            Field::new("b", DataType::Int32, true),
1805            Field::new("c", DataType::Int32, true),
1806        ]);
1807        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1808        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1809        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1810        let batch = RecordBatch::try_new(
1811            Arc::new(schema),
1812            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1813        )?;
1814        Ok(batch)
1815    }
1816
1817    #[test]
1818    fn case_without_expr_else() -> Result<()> {
1819        let batch = case_test_batch()?;
1820        let schema = batch.schema();
1821
1822        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
1823        let when1 = binary(
1824            col("a", &schema)?,
1825            Operator::Eq,
1826            lit("foo"),
1827            &batch.schema(),
1828        )?;
1829        let then1 = lit(123i32);
1830        let when2 = binary(
1831            col("a", &schema)?,
1832            Operator::Eq,
1833            lit("bar"),
1834            &batch.schema(),
1835        )?;
1836        let then2 = lit(456i32);
1837        let else_value = lit(999i32);
1838
1839        let expr = generate_case_when_with_type_coercion(
1840            None,
1841            vec![(when1, then1), (when2, then2)],
1842            Some(else_value),
1843            schema.as_ref(),
1844        )?;
1845        let result = expr
1846            .evaluate(&batch)?
1847            .into_array(batch.num_rows())
1848            .expect("Failed to convert to array");
1849        let result = as_int32_array(&result)?;
1850
1851        let expected =
1852            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1853
1854        assert_eq!(expected, result);
1855
1856        Ok(())
1857    }
1858
1859    #[test]
1860    fn case_with_type_cast() -> Result<()> {
1861        let batch = case_test_batch()?;
1862        let schema = batch.schema();
1863
1864        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
1865        let when = binary(
1866            col("a", &schema)?,
1867            Operator::Eq,
1868            lit("foo"),
1869            &batch.schema(),
1870        )?;
1871        let then = lit(123.3f64);
1872        let else_value = lit(999i32);
1873
1874        let expr = generate_case_when_with_type_coercion(
1875            None,
1876            vec![(when, then)],
1877            Some(else_value),
1878            schema.as_ref(),
1879        )?;
1880        let result = expr
1881            .evaluate(&batch)?
1882            .into_array(batch.num_rows())
1883            .expect("Failed to convert to array");
1884        let result =
1885            as_float64_array(&result).expect("failed to downcast to Float64Array");
1886
1887        let expected =
1888            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
1889
1890        assert_eq!(expected, result);
1891
1892        Ok(())
1893    }
1894
1895    #[test]
1896    fn case_with_matches_and_nulls() -> Result<()> {
1897        let batch = case_test_batch_nulls()?;
1898        let schema = batch.schema();
1899
1900        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
1901        let when = binary(
1902            col("load4", &schema)?,
1903            Operator::Eq,
1904            lit(1.77f64),
1905            &batch.schema(),
1906        )?;
1907        let then = col("load4", &schema)?;
1908
1909        let expr = generate_case_when_with_type_coercion(
1910            None,
1911            vec![(when, then)],
1912            None,
1913            schema.as_ref(),
1914        )?;
1915        let result = expr
1916            .evaluate(&batch)?
1917            .into_array(batch.num_rows())
1918            .expect("Failed to convert to array");
1919        let result =
1920            as_float64_array(&result).expect("failed to downcast to Float64Array");
1921
1922        let expected =
1923            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1924
1925        assert_eq!(expected, result);
1926
1927        Ok(())
1928    }
1929
1930    #[test]
1931    fn case_with_scalar_predicate() -> Result<()> {
1932        let batch = case_test_batch_nulls()?;
1933        let schema = batch.schema();
1934
1935        // SELECT CASE WHEN TRUE THEN load4 END
1936        let when = lit(true);
1937        let then = col("load4", &schema)?;
1938        let expr = generate_case_when_with_type_coercion(
1939            None,
1940            vec![(when, then)],
1941            None,
1942            schema.as_ref(),
1943        )?;
1944
1945        // many rows
1946        let result = expr
1947            .evaluate(&batch)?
1948            .into_array(batch.num_rows())
1949            .expect("Failed to convert to array");
1950        let result =
1951            as_float64_array(&result).expect("failed to downcast to Float64Array");
1952        let expected = &Float64Array::from(vec![
1953            Some(1.77),
1954            None,
1955            None,
1956            Some(1.78),
1957            None,
1958            Some(1.77),
1959        ]);
1960        assert_eq!(expected, result);
1961
1962        // one row
1963        let expected = Float64Array::from(vec![Some(1.1)]);
1964        let batch =
1965            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
1966        let result = expr
1967            .evaluate(&batch)?
1968            .into_array(batch.num_rows())
1969            .expect("Failed to convert to array");
1970        let result =
1971            as_float64_array(&result).expect("failed to downcast to Float64Array");
1972        assert_eq!(&expected, result);
1973
1974        Ok(())
1975    }
1976
1977    #[test]
1978    fn case_expr_matches_and_nulls() -> Result<()> {
1979        let batch = case_test_batch_nulls()?;
1980        let schema = batch.schema();
1981
1982        // SELECT CASE load4 WHEN 1.77 THEN load4 END
1983        let expr = col("load4", &schema)?;
1984        let when = lit(1.77f64);
1985        let then = col("load4", &schema)?;
1986
1987        let expr = generate_case_when_with_type_coercion(
1988            Some(expr),
1989            vec![(when, then)],
1990            None,
1991            schema.as_ref(),
1992        )?;
1993        let result = expr
1994            .evaluate(&batch)?
1995            .into_array(batch.num_rows())
1996            .expect("Failed to convert to array");
1997        let result =
1998            as_float64_array(&result).expect("failed to downcast to Float64Array");
1999
2000        let expected =
2001            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
2002
2003        assert_eq!(expected, result);
2004
2005        Ok(())
2006    }
2007
2008    #[test]
2009    fn test_when_null_and_some_cond_else_null() -> Result<()> {
2010        let batch = case_test_batch()?;
2011        let schema = batch.schema();
2012
2013        let when = binary(
2014            Arc::new(Literal::new(ScalarValue::Boolean(None))),
2015            Operator::And,
2016            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
2017            &schema,
2018        )?;
2019        let then = col("a", &schema)?;
2020
2021        // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END
2022        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
2023        let result = expr
2024            .evaluate(&batch)?
2025            .into_array(batch.num_rows())
2026            .expect("Failed to convert to array");
2027        let result = as_string_array(&result);
2028
2029        // all result values should be null
2030        assert_eq!(result.logical_null_count(), batch.num_rows());
2031        Ok(())
2032    }
2033
2034    fn case_test_batch() -> Result<RecordBatch> {
2035        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2036        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
2037        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
2038        Ok(batch)
2039    }
2040
2041    // Construct an array that has several NULL values whose
2042    // underlying buffer actually matches the where expr predicate
2043    fn case_test_batch_nulls() -> Result<RecordBatch> {
2044        let load4: Float64Array = vec![
2045            Some(1.77), // 1.77
2046            Some(1.77), // null <-- same value, but will be set to null
2047            Some(1.77), // null <-- same value, but will be set to null
2048            Some(1.78), // 1.78
2049            None,       // null
2050            Some(1.77), // 1.77
2051        ]
2052        .into_iter()
2053        .collect();
2054
2055        let null_buffer = Buffer::from([0b00101001u8]);
2056        let load4 = load4
2057            .into_data()
2058            .into_builder()
2059            .null_bit_buffer(Some(null_buffer))
2060            .build()
2061            .unwrap();
2062        let load4: Float64Array = load4.into();
2063
2064        let batch =
2065            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
2066        Ok(batch)
2067    }
2068
2069    #[test]
2070    fn case_test_incompatible() -> Result<()> {
2071        // 1 then is int64
2072        // 2 then is boolean
2073        let batch = case_test_batch()?;
2074        let schema = batch.schema();
2075
2076        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
2077        let when1 = binary(
2078            col("a", &schema)?,
2079            Operator::Eq,
2080            lit("foo"),
2081            &batch.schema(),
2082        )?;
2083        let then1 = lit(123i32);
2084        let when2 = binary(
2085            col("a", &schema)?,
2086            Operator::Eq,
2087            lit("bar"),
2088            &batch.schema(),
2089        )?;
2090        let then2 = lit(true);
2091
2092        let expr = generate_case_when_with_type_coercion(
2093            None,
2094            vec![(when1, then1), (when2, then2)],
2095            None,
2096            schema.as_ref(),
2097        );
2098        assert!(expr.is_err());
2099
2100        // then 1 is int32
2101        // then 2 is int64
2102        // else is float
2103        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
2104        let when1 = binary(
2105            col("a", &schema)?,
2106            Operator::Eq,
2107            lit("foo"),
2108            &batch.schema(),
2109        )?;
2110        let then1 = lit(123i32);
2111        let when2 = binary(
2112            col("a", &schema)?,
2113            Operator::Eq,
2114            lit("bar"),
2115            &batch.schema(),
2116        )?;
2117        let then2 = lit(456i64);
2118        let else_expr = lit(1.23f64);
2119
2120        let expr = generate_case_when_with_type_coercion(
2121            None,
2122            vec![(when1, then1), (when2, then2)],
2123            Some(else_expr),
2124            schema.as_ref(),
2125        );
2126        assert!(expr.is_ok());
2127        let result_type = expr.unwrap().data_type(schema.as_ref())?;
2128        assert_eq!(Float64, result_type);
2129        Ok(())
2130    }
2131
2132    #[test]
2133    fn case_eq() -> Result<()> {
2134        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2135
2136        let when1 = lit("foo");
2137        let then1 = lit(123i32);
2138        let when2 = lit("bar");
2139        let then2 = lit(456i32);
2140        let else_value = lit(999i32);
2141
2142        let expr1 = generate_case_when_with_type_coercion(
2143            Some(col("a", &schema)?),
2144            vec![
2145                (Arc::clone(&when1), Arc::clone(&then1)),
2146                (Arc::clone(&when2), Arc::clone(&then2)),
2147            ],
2148            Some(Arc::clone(&else_value)),
2149            &schema,
2150        )?;
2151
2152        let expr2 = generate_case_when_with_type_coercion(
2153            Some(col("a", &schema)?),
2154            vec![
2155                (Arc::clone(&when1), Arc::clone(&then1)),
2156                (Arc::clone(&when2), Arc::clone(&then2)),
2157            ],
2158            Some(Arc::clone(&else_value)),
2159            &schema,
2160        )?;
2161
2162        let expr3 = generate_case_when_with_type_coercion(
2163            Some(col("a", &schema)?),
2164            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
2165            None,
2166            &schema,
2167        )?;
2168
2169        let expr4 = generate_case_when_with_type_coercion(
2170            Some(col("a", &schema)?),
2171            vec![(when1, then1)],
2172            Some(else_value),
2173            &schema,
2174        )?;
2175
2176        assert!(expr1.eq(&expr2));
2177        assert!(expr2.eq(&expr1));
2178
2179        assert!(expr2.ne(&expr3));
2180        assert!(expr3.ne(&expr2));
2181
2182        assert!(expr1.ne(&expr4));
2183        assert!(expr4.ne(&expr1));
2184
2185        Ok(())
2186    }
2187
2188    #[test]
2189    fn case_transform() -> Result<()> {
2190        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2191
2192        let when1 = lit("foo");
2193        let then1 = lit(123i32);
2194        let when2 = lit("bar");
2195        let then2 = lit(456i32);
2196        let else_value = lit(999i32);
2197
2198        let expr = generate_case_when_with_type_coercion(
2199            Some(col("a", &schema)?),
2200            vec![
2201                (Arc::clone(&when1), Arc::clone(&then1)),
2202                (Arc::clone(&when2), Arc::clone(&then2)),
2203            ],
2204            Some(Arc::clone(&else_value)),
2205            &schema,
2206        )?;
2207
2208        let expr2 = Arc::clone(&expr)
2209            .transform(|e| {
2210                let transformed = match e.as_any().downcast_ref::<Literal>() {
2211                    Some(lit_value) => match lit_value.value() {
2212                        ScalarValue::Utf8(Some(str_value)) => {
2213                            Some(lit(str_value.to_uppercase()))
2214                        }
2215                        _ => None,
2216                    },
2217                    _ => None,
2218                };
2219                Ok(if let Some(transformed) = transformed {
2220                    Transformed::yes(transformed)
2221                } else {
2222                    Transformed::no(e)
2223                })
2224            })
2225            .data()
2226            .unwrap();
2227
2228        let expr3 = Arc::clone(&expr)
2229            .transform_down(|e| {
2230                let transformed = match e.as_any().downcast_ref::<Literal>() {
2231                    Some(lit_value) => match lit_value.value() {
2232                        ScalarValue::Utf8(Some(str_value)) => {
2233                            Some(lit(str_value.to_uppercase()))
2234                        }
2235                        _ => None,
2236                    },
2237                    _ => None,
2238                };
2239                Ok(if let Some(transformed) = transformed {
2240                    Transformed::yes(transformed)
2241                } else {
2242                    Transformed::no(e)
2243                })
2244            })
2245            .data()
2246            .unwrap();
2247
2248        assert!(expr.ne(&expr2));
2249        assert!(expr2.eq(&expr3));
2250
2251        Ok(())
2252    }
2253
2254    #[test]
2255    fn test_column_or_null_specialization() -> Result<()> {
2256        // create input data
2257        let mut c1 = Int32Builder::new();
2258        let mut c2 = StringBuilder::new();
2259        for i in 0..1000 {
2260            c1.append_value(i);
2261            if i % 7 == 0 {
2262                c2.append_null();
2263            } else {
2264                c2.append_value(format!("string {i}"));
2265            }
2266        }
2267        let c1 = Arc::new(c1.finish());
2268        let c2 = Arc::new(c2.finish());
2269        let schema = Schema::new(vec![
2270            Field::new("c1", DataType::Int32, true),
2271            Field::new("c2", DataType::Utf8, true),
2272        ]);
2273        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
2274
2275        // CaseWhenExprOrNull should produce same results as CaseExpr
2276        let predicate = Arc::new(BinaryExpr::new(
2277            make_col("c1", 0),
2278            Operator::LtEq,
2279            make_lit_i32(250),
2280        ));
2281        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
2282        assert_eq!(expr.eval_method, EvalMethod::InfallibleExprOrNull);
2283        match expr.evaluate(&batch)? {
2284            ColumnarValue::Array(array) => {
2285                assert_eq!(1000, array.len());
2286                assert_eq!(785, array.null_count());
2287            }
2288            _ => unreachable!(),
2289        }
2290        Ok(())
2291    }
2292
2293    #[test]
2294    fn test_expr_or_expr_specialization() -> Result<()> {
2295        let batch = case_test_batch1()?;
2296        let schema = batch.schema();
2297        let when = binary(
2298            col("a", &schema)?,
2299            Operator::LtEq,
2300            lit(2i32),
2301            &batch.schema(),
2302        )?;
2303        let then = col("b", &schema)?;
2304        let else_expr = col("c", &schema)?;
2305        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
2306        assert!(matches!(
2307            expr.eval_method,
2308            EvalMethod::ExpressionOrExpression(_)
2309        ));
2310        let result = expr
2311            .evaluate(&batch)?
2312            .into_array(batch.num_rows())
2313            .expect("Failed to convert to array");
2314        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
2315
2316        let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
2317
2318        assert_eq!(expected, result);
2319        Ok(())
2320    }
2321
2322    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2323        Arc::new(Column::new(name, index))
2324    }
2325
2326    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2327        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2328    }
2329
2330    fn generate_case_when_with_type_coercion(
2331        expr: Option<Arc<dyn PhysicalExpr>>,
2332        when_thens: Vec<WhenThen>,
2333        else_expr: Option<Arc<dyn PhysicalExpr>>,
2334        input_schema: &Schema,
2335    ) -> Result<Arc<dyn PhysicalExpr>> {
2336        let coerce_type =
2337            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2338        let (when_thens, else_expr) = match coerce_type {
2339            None => plan_err!(
2340                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2341            ),
2342            Some(data_type) => {
2343                // cast then expr
2344                let left = when_thens
2345                    .into_iter()
2346                    .map(|(when, then)| {
2347                        let then = try_cast(then, input_schema, data_type.clone())?;
2348                        Ok((when, then))
2349                    })
2350                    .collect::<Result<Vec<_>>>()?;
2351                let right = match else_expr {
2352                    None => None,
2353                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2354                };
2355
2356                Ok((left, right))
2357            }
2358        }?;
2359        case(expr, when_thens, else_expr)
2360    }
2361
2362    fn get_case_common_type(
2363        when_thens: &[WhenThen],
2364        else_expr: Option<Arc<dyn PhysicalExpr>>,
2365        input_schema: &Schema,
2366    ) -> Option<DataType> {
2367        let thens_type = when_thens
2368            .iter()
2369            .map(|when_then| {
2370                let data_type = &when_then.1.data_type(input_schema).unwrap();
2371                data_type.clone()
2372            })
2373            .collect::<Vec<_>>();
2374        let else_type = match else_expr {
2375            None => {
2376                // case when then exprs must have one then value
2377                thens_type[0].clone()
2378            }
2379            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2380        };
2381        thens_type
2382            .iter()
2383            .try_fold(else_type, |left_type, right_type| {
2384                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
2385                // refactor again.
2386                comparison_coercion(&left_type, right_type)
2387            })
2388    }
2389
/// Checks both the `Display` output and the `fmt_sql` output of a `CASE`
/// expression whose branches were coerced to a common type.
#[test]
fn test_fmt_sql() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);

    // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
    let predicate = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(predicate, lit(123.3f64))],
        Some(lit(999i32)),
        &schema,
    )?;

    // `Display` renders the column together with its index (`a@0`).
    assert_eq!(
        case_expr.to_string(),
        "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
    );

    // `fmt_sql` renders the bare column name.
    assert_eq!(
        fmt_sql(case_expr.as_ref()).to_string(),
        "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
    );

    Ok(())
}
2420
2421    fn when_then_else(
2422        when: &Arc<dyn PhysicalExpr>,
2423        then: &Arc<dyn PhysicalExpr>,
2424        els: &Arc<dyn PhysicalExpr>,
2425    ) -> Result<Arc<dyn PhysicalExpr>> {
2426        let case = CaseExpr::try_new(
2427            None,
2428            vec![(Arc::clone(when), Arc::clone(then))],
2429            Some(Arc::clone(els)),
2430        )?;
2431        Ok(Arc::new(case))
2432    }
2433
/// Runs the shared nullability checks with `foo` declared nullable.
#[test]
fn test_case_expression_nullability_with_nullable_column() -> Result<()> {
    let col_is_nullable = true;
    case_expression_nullability(col_is_nullable)
}
2438
/// Runs the shared nullability checks with `foo` declared non-nullable.
#[test]
fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> {
    let col_is_nullable = false;
    case_expression_nullability(col_is_nullable)
}
2443
2444    fn case_expression_nullability(col_is_nullable: bool) -> Result<()> {
2445        let schema =
2446            Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]);
2447
2448        let foo = col("foo", &schema)?;
2449        let foo_is_not_null = is_not_null(Arc::clone(&foo))?;
2450        let foo_is_null = expressions::is_null(Arc::clone(&foo))?;
2451        let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?;
2452        let zero = lit(0);
2453        let foo_eq_zero =
2454            binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?;
2455
2456        assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema);
2457        assert_not_nullable(when_then_else(&not_foo_is_null, &foo, &zero)?, &schema);
2458        assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema);
2459
2460        assert_not_nullable(
2461            when_then_else(
2462                &binary(
2463                    Arc::clone(&foo_is_not_null),
2464                    Operator::And,
2465                    Arc::clone(&foo_eq_zero),
2466                    &schema,
2467                )?,
2468                &foo,
2469                &zero,
2470            )?,
2471            &schema,
2472        );
2473
2474        assert_not_nullable(
2475            when_then_else(
2476                &binary(
2477                    Arc::clone(&foo_eq_zero),
2478                    Operator::And,
2479                    Arc::clone(&foo_is_not_null),
2480                    &schema,
2481                )?,
2482                &foo,
2483                &zero,
2484            )?,
2485            &schema,
2486        );
2487
2488        assert_not_nullable(
2489            when_then_else(
2490                &binary(
2491                    Arc::clone(&foo_is_not_null),
2492                    Operator::Or,
2493                    Arc::clone(&foo_eq_zero),
2494                    &schema,
2495                )?,
2496                &foo,
2497                &zero,
2498            )?,
2499            &schema,
2500        );
2501
2502        assert_not_nullable(
2503            when_then_else(
2504                &binary(
2505                    Arc::clone(&foo_eq_zero),
2506                    Operator::Or,
2507                    Arc::clone(&foo_is_not_null),
2508                    &schema,
2509                )?,
2510                &foo,
2511                &zero,
2512            )?,
2513            &schema,
2514        );
2515
2516        assert_nullability(
2517            when_then_else(
2518                &binary(
2519                    Arc::clone(&foo_is_null),
2520                    Operator::Or,
2521                    Arc::clone(&foo_eq_zero),
2522                    &schema,
2523                )?,
2524                &foo,
2525                &zero,
2526            )?,
2527            &schema,
2528            col_is_nullable,
2529        );
2530
2531        assert_nullability(
2532            when_then_else(
2533                &binary(
2534                    binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?,
2535                    Operator::Or,
2536                    Arc::clone(&foo_is_null),
2537                    &schema,
2538                )?,
2539                &foo,
2540                &zero,
2541            )?,
2542            &schema,
2543            col_is_nullable,
2544        );
2545
2546        assert_not_nullable(
2547            when_then_else(
2548                &binary(
2549                    binary(
2550                        binary(
2551                            Arc::clone(&foo),
2552                            Operator::Eq,
2553                            Arc::clone(&zero),
2554                            &schema,
2555                        )?,
2556                        Operator::And,
2557                        Arc::clone(&foo_is_not_null),
2558                        &schema,
2559                    )?,
2560                    Operator::Or,
2561                    binary(
2562                        binary(
2563                            Arc::clone(&foo),
2564                            Operator::Eq,
2565                            Arc::clone(&foo),
2566                            &schema,
2567                        )?,
2568                        Operator::And,
2569                        Arc::clone(&foo_is_not_null),
2570                        &schema,
2571                    )?,
2572                    &schema,
2573                )?,
2574                &foo,
2575                &zero,
2576            )?,
2577            &schema,
2578        );
2579
2580        Ok(())
2581    }
2582
2583    fn assert_not_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2584        assert!(!expr.nullable(schema).unwrap());
2585    }
2586
2587    fn assert_nullable(expr: Arc<dyn PhysicalExpr>, schema: &Schema) {
2588        assert!(expr.nullable(schema).unwrap());
2589    }
2590
2591    fn assert_nullability(expr: Arc<dyn PhysicalExpr>, schema: &Schema, nullable: bool) {
2592        if nullable {
2593            assert_nullable(expr, schema);
2594        } else {
2595            assert_not_nullable(expr, schema);
2596        }
2597    }
2598
2599    // Test Lookup evaluation
2600
2601    fn test_case_when_literal_lookup(
2602        values: ArrayRef,
2603        lookup_map: &[(ScalarValue, ScalarValue)],
2604        else_value: Option<ScalarValue>,
2605        expected: ArrayRef,
2606    ) {
2607        // Create lookup
2608        // CASE <expr>
2609        // WHEN <when_constant_1> THEN <then_constant_1>
2610        // WHEN <when_constant_2> THEN <then_constant_2>
2611        // [ ELSE <else_constant> ]
2612
2613        let schema = Schema::new(vec![Field::new(
2614            "a",
2615            values.data_type().clone(),
2616            values.is_nullable(),
2617        )]);
2618        let schema = Arc::new(schema);
2619
2620        let batch = RecordBatch::try_new(schema, vec![values])
2621            .expect("failed to create RecordBatch");
2622
2623        let schema = batch.schema_ref();
2624        let case = col("a", schema).expect("failed to create col");
2625
2626        let when_then = lookup_map
2627            .iter()
2628            .map(|(when, then)| {
2629                (
2630                    Arc::new(Literal::new(when.clone())) as _,
2631                    Arc::new(Literal::new(then.clone())) as _,
2632                )
2633            })
2634            .collect::<Vec<WhenThen>>();
2635
2636        let else_expr = else_value.map(|else_value| {
2637            Arc::new(Literal::new(else_value)) as Arc<dyn PhysicalExpr>
2638        });
2639        let expr = CaseExpr::try_new(Some(case), when_then, else_expr)
2640            .expect("failed to create case");
2641
2642        // Assert that we are testing what we intend to assert
2643        assert!(
2644            matches!(
2645                expr.eval_method,
2646                EvalMethod::WithExprScalarLookupTable { .. }
2647            ),
2648            "we should use the expected eval method"
2649        );
2650
2651        let actual = expr
2652            .evaluate(&batch)
2653            .expect("failed to evaluate case")
2654            .into_array(batch.num_rows())
2655            .expect("Failed to convert to array");
2656
2657        assert_eq!(
2658            actual.data_type(),
2659            expected.data_type(),
2660            "Data type mismatch"
2661        );
2662
2663        assert_eq!(
2664            actual.as_ref(),
2665            expected.as_ref(),
2666            "actual (left) does not match expected (right)"
2667        );
2668    }
2669
2670    fn create_lookup<When, Then>(
2671        when_then_pairs: impl IntoIterator<Item = (When, Then)>,
2672    ) -> Vec<(ScalarValue, ScalarValue)>
2673    where
2674        ScalarValue: From<When>,
2675        ScalarValue: From<Then>,
2676    {
2677        when_then_pairs
2678            .into_iter()
2679            .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then)))
2680            .collect()
2681    }
2682
2683    fn create_input_and_expected<Input, Expected, InputFromItem, ExpectedFromItem>(
2684        input_and_expected_pairs: impl IntoIterator<Item = (InputFromItem, ExpectedFromItem)>,
2685    ) -> (Input, Expected)
2686    where
2687        Input: Array + From<Vec<InputFromItem>>,
2688        Expected: Array + From<Vec<ExpectedFromItem>>,
2689    {
2690        let (input_items, expected_items): (Vec<InputFromItem>, Vec<ExpectedFromItem>) =
2691            input_and_expected_pairs.into_iter().unzip();
2692
2693        (Input::from(input_items), Expected::from(expected_items))
2694    }
2695
2696    fn test_lookup_eval_with_and_without_else(
2697        lookup_map: &[(ScalarValue, ScalarValue)],
2698        input_values: ArrayRef,
2699        expected: StringArray,
2700    ) {
2701        // Testing without ELSE should fallback to None
2702        test_case_when_literal_lookup(
2703            Arc::clone(&input_values),
2704            lookup_map,
2705            None,
2706            Arc::new(expected.clone()),
2707        );
2708
2709        // Testing with Else
2710        let else_value = "___fallback___";
2711
2712        // Changing each expected None to be fallback
2713        let expected_with_else = expected
2714            .iter()
2715            .map(|item| item.unwrap_or(else_value))
2716            .map(Some)
2717            .collect::<StringArray>();
2718
2719        // Test case
2720        test_case_when_literal_lookup(
2721            input_values,
2722            lookup_map,
2723            Some(ScalarValue::Utf8(Some(else_value.to_string()))),
2724            Arc::new(expected_with_else),
2725        );
2726    }
2727
/// `Int32` keys mapped to string values, with some inputs matching no branch.
#[test]
fn test_case_when_literal_lookup_int32_to_string() {
    let branches = create_lookup([
        (Some(4), Some("four")),
        (Some(2), Some("two")),
        (Some(3), Some("three")),
        (Some(1), Some("one")),
    ]);

    let (inputs, outputs): (Int32Array, StringArray) = create_input_and_expected([
        (1, Some("one")),
        (2, Some("two")),
        (3, Some("three")),
        (3, Some("three")),
        (2, Some("two")),
        (3, Some("three")),
        (5, None), // no WHEN matches
        (5, None), // no WHEN matches
        (3, Some("three")),
        (5, None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2757
/// A `WHEN NULL` branch and null inputs: neither may ever produce a match.
#[test]
fn test_case_when_literal_lookup_none_case_should_never_match() {
    let branches = create_lookup([
        (Some(4), Some("four")),
        (None, Some("none")),
        (Some(2), Some("two")),
        (Some(1), Some("one")),
    ]);

    let (inputs, outputs): (Int32Array, StringArray) = create_input_and_expected([
        (Some(1), Some("one")),
        (Some(5), None), // no WHEN matches
        (None, None), // a NULL never matches in `CASE <expr> WHEN <value>` syntax
        (Some(2), Some("two")),
        (None, None), // a NULL never matches in `CASE <expr> WHEN <value>` syntax
        (None, None), // a NULL never matches in `CASE <expr> WHEN <value>` syntax
        (Some(2), Some("two")),
        (Some(5), None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2785
/// Repeated `WHEN` keys: only the first branch for a given key may be used.
#[test]
fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() {
    let branches = create_lookup([
        (Some(4), Some("four")),
        (Some(4), Some("no 4")),
        (Some(2), Some("two")),
        (Some(2), Some("no 2")),
        (Some(3), Some("three")),
        (Some(3), Some("no 3")),
        (Some(2), Some("no 2")),
        (Some(4), Some("no 4")),
        (Some(2), Some("no 2")),
        (Some(3), Some("no 3")),
        (Some(4), Some("no 4")),
        (Some(2), Some("no 2")),
        (Some(3), Some("no 3")),
        (Some(3), Some("no 3")),
    ]);

    let (inputs, outputs): (Int32Array, StringArray) = create_input_and_expected([
        (1, None), // no WHEN matches
        (2, Some("two")),
        (3, Some("three")),
        (3, Some("three")),
        (2, Some("two")),
        (3, Some("three")),
        (5, None), // no WHEN matches
        (5, None), // no WHEN matches
        (3, Some("three")),
        (5, None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2825
/// `Float32` keys including NaN and infinity, with duplicated special keys
/// whose later branches must be ignored.
#[test]
fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases() {
    let branches = create_lookup([
        (Some(4.0), Some("four point zero")),
        (Some(f32::NAN), Some("NaN")),
        (Some(3.2), Some("three point two")),
        // Second NaN branch: must lose to the first one
        (Some(f32::NAN), Some("should not use this NaN branch")),
        (Some(f32::INFINITY), Some("Infinity")),
        (Some(0.0), Some("zero")),
        // Second Infinity branch: must lose to the first one
        (
            Some(f32::INFINITY),
            Some("should not use this Infinity branch"),
        ),
        (Some(1.1), Some("one point one")),
    ]);

    let (inputs, outputs): (Float32Array, StringArray) = create_input_and_expected([
        (1.1, Some("one point one")),
        (f32::NAN, Some("NaN")),
        (3.2, Some("three point two")),
        (3.2, Some("three point two")),
        (0.0, Some("zero")),
        (f32::INFINITY, Some("Infinity")),
        (3.2, Some("three point two")),
        (f32::NEG_INFINITY, None), // no WHEN matches
        (f32::NEG_INFINITY, None), // no WHEN matches
        (3.2, Some("three point two")),
        (-0.0, None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2866
/// `Float16` keys including NaN, infinity and zero; negative infinity and
/// negative zero inputs are expected to match nothing.
#[test]
fn test_case_when_literal_lookup_f16_to_string_with_special_values() {
    // Wraps a half-precision value in a non-null `Float16` scalar.
    let key = |value: f16| ScalarValue::Float16(Some(value));

    let branches = create_lookup([
        (key(f16::from_f32(3.2)), Some("3 dot 2")),
        (key(f16::NAN), Some("NaN")),
        (key(f16::from_f32(17.4)), Some("17 dot 4")),
        (key(f16::INFINITY), Some("Infinity")),
        (key(f16::ZERO), Some("zero")),
    ]);

    let (inputs, outputs): (Float16Array, StringArray) = create_input_and_expected([
        (f16::from_f32(3.2), Some("3 dot 2")),
        (f16::NAN, Some("NaN")),
        (f16::from_f32(17.4), Some("17 dot 4")),
        (f16::from_f32(17.4), Some("17 dot 4")),
        (f16::INFINITY, Some("Infinity")),
        (f16::from_f32(17.4), Some("17 dot 4")),
        (f16::NEG_INFINITY, None), // no WHEN matches
        (f16::NEG_INFINITY, None), // no WHEN matches
        (f16::from_f32(17.4), Some("17 dot 4")),
        (f16::NEG_ZERO, None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2903
/// `Float32` keys including NaN, infinity and zero; negative infinity and
/// negative zero inputs are expected to match nothing.
#[test]
fn test_case_when_literal_lookup_f32_to_string_with_special_values() {
    let branches = create_lookup([
        (3.2, Some("3 dot 2")),
        (f32::NAN, Some("NaN")),
        (17.4, Some("17 dot 4")),
        (f32::INFINITY, Some("Infinity")),
        (f32::ZERO, Some("zero")),
    ]);

    let (inputs, outputs): (Float32Array, StringArray) = create_input_and_expected([
        (3.2, Some("3 dot 2")),
        (f32::NAN, Some("NaN")),
        (17.4, Some("17 dot 4")),
        (17.4, Some("17 dot 4")),
        (f32::INFINITY, Some("Infinity")),
        (17.4, Some("17 dot 4")),
        (f32::NEG_INFINITY, None), // no WHEN matches
        (f32::NEG_INFINITY, None), // no WHEN matches
        (17.4, Some("17 dot 4")),
        (-0.0, None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2934
/// `Float64` keys including NaN, infinity and zero; negative infinity and
/// negative zero inputs are expected to match nothing.
#[test]
fn test_case_when_literal_lookup_f64_to_string_with_special_values() {
    let branches = create_lookup([
        (3.2, Some("3 dot 2")),
        (f64::NAN, Some("NaN")),
        (17.4, Some("17 dot 4")),
        (f64::INFINITY, Some("Infinity")),
        (f64::ZERO, Some("zero")),
    ]);

    let (inputs, outputs): (Float64Array, StringArray) = create_input_and_expected([
        (3.2, Some("3 dot 2")),
        (f64::NAN, Some("NaN")),
        (17.4, Some("17 dot 4")),
        (17.4, Some("17 dot 4")),
        (f64::INFINITY, Some("Infinity")),
        (17.4, Some("17 dot 4")),
        (f64::NEG_INFINITY, None), // no WHEN matches
        (f64::NEG_INFINITY, None), // no WHEN matches
        (17.4, Some("17 dot 4")),
        (-0.0, None), // no WHEN matches
    ]);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
2965
/// The decimal precision and scale of the keys must not be lost.
#[test]
fn test_decimal_with_non_default_precision_and_scale() {
    // Non-null `Decimal32` scalar with precision 3 and scale 2.
    let key = |value: i32| ScalarValue::Decimal32(Some(value), 3, 2);

    let branches = create_lookup([
        (key(4), Some("four")),
        (key(2), Some("two")),
        (key(3), Some("three")),
        (key(1), Some("one")),
    ]);

    let (inputs, outputs): (Decimal32Array, StringArray) = create_input_and_expected([
        (1, Some("one")),
        (2, Some("two")),
        (3, Some("three")),
        (3, Some("three")),
        (2, Some("two")),
        (3, Some("three")),
        (5, None), // no WHEN matches
        (5, None), // no WHEN matches
        (3, Some("three")),
        (5, None), // no WHEN matches
    ]);

    // Give the input column the same precision and scale as the WHEN literals.
    let inputs = inputs
        .with_precision_and_scale(3, 2)
        .expect("must be able to set precision and scale");

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
3000
/// The timezone of timestamp keys must not be lost.
#[test]
fn test_timestamp_with_non_default_timezone() {
    let timezone: Option<Arc<str>> = Some("-10:00".into());

    // Non-null millisecond timestamp scalar carrying `timezone`.
    let key =
        |millis: i64| ScalarValue::TimestampMillisecond(Some(millis), timezone.clone());

    let branches = create_lookup([
        (key(4), Some("four")),
        (key(2), Some("two")),
        (key(3), Some("three")),
        (key(1), Some("one")),
    ]);

    let (inputs, outputs): (TimestampMillisecondArray, StringArray) =
        create_input_and_expected([
            (1, Some("one")),
            (2, Some("two")),
            (3, Some("three")),
            (3, Some("three")),
            (2, Some("two")),
            (3, Some("three")),
            (5, None), // no WHEN matches
            (5, None), // no WHEN matches
            (3, Some("three")),
            (5, None), // no WHEN matches
        ]);

    // Give the input column the same timezone as the WHEN literals.
    let inputs = inputs.with_timezone_opt(timezone);

    test_lookup_eval_with_and_without_else(&branches, Arc::new(inputs), outputs);
}
3046
/// String keys mapped to `Int32` values, checked both without an `ELSE`
/// branch and with an `Int32` `ELSE` literal.
#[test]
fn test_with_strings_to_int32() {
    const ELSE_VALUE: i32 = 101;

    let branches = create_lookup([
        (Some("why"), Some(42)),
        (Some("what"), Some(22)),
        (Some("when"), Some(17)),
    ]);

    let (inputs, outputs): (StringArray, Int32Array) = create_input_and_expected([
        (Some("why"), Some(42)),
        (Some("5"), None), // no WHEN matches
        (None, None), // a NULL never matches in `CASE <expr> WHEN <value>` syntax
        (Some("what"), Some(22)),
        (None, None), // a NULL never matches in `CASE <expr> WHEN <value>` syntax
        (None, None), // a NULL never matches in `CASE <expr> WHEN <value>` syntax
        (Some("what"), Some(22)),
        (Some("5"), None), // no WHEN matches
    ]);

    let inputs = Arc::new(inputs) as ArrayRef;

    // Expected output once an ELSE branch exists: nulls become the fallback.
    let outputs_with_else: Int32Array = outputs
        .iter()
        .map(|item| item.or(Some(ELSE_VALUE)))
        .collect();

    // Without ELSE, rows that match no WHEN fall back to None.
    test_case_when_literal_lookup(
        Arc::clone(&inputs),
        &branches,
        None,
        Arc::new(outputs),
    );

    // With ELSE.
    test_case_when_literal_lookup(
        inputs,
        &branches,
        Some(ScalarValue::Int32(Some(ELSE_VALUE))),
        Arc::new(outputs_with_else),
    );
}
3095}