datafusion_physical_expr/expressions/
case.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use super::{Column, Literal};
19use crate::expressions::case::ResultState::{Complete, Empty, Partial};
20use crate::expressions::try_cast;
21use crate::PhysicalExpr;
22use arrow::array::*;
23use arrow::compute::kernels::zip::zip;
24use arrow::compute::{
25    is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate,
26    SlicesIterator,
27};
28use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
29use arrow::error::ArrowError;
30use datafusion_common::cast::as_boolean_array;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
32use datafusion_common::{
33    exec_err, internal_datafusion_err, internal_err, DataFusionError, HashMap, HashSet,
34    Result, ScalarValue,
35};
36use datafusion_expr::ColumnarValue;
37use datafusion_physical_expr_common::datum::compare_with_eq;
38use itertools::Itertools;
39use std::borrow::Cow;
40use std::fmt::{Debug, Formatter};
41use std::hash::Hash;
42use std::{any::Any, sync::Arc};
43
44type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
45
46#[derive(Debug, Hash, PartialEq, Eq)]
47enum EvalMethod {
48    /// CASE WHEN condition THEN result
49    ///      [WHEN ...]
50    ///      [ELSE result]
51    /// END
52    NoExpression(ProjectedCaseBody),
53    /// CASE expression
54    ///     WHEN value THEN result
55    ///     [WHEN ...]
56    ///     [ELSE result]
57    /// END
58    WithExpression(ProjectedCaseBody),
59    /// This is a specialization for a specific use case where we can take a fast path
60    /// for expressions that are infallible and can be cheaply computed for the entire
61    /// record batch rather than just for the rows where the predicate is true.
62    ///
63    /// CASE WHEN condition THEN column [ELSE NULL] END
64    InfallibleExprOrNull,
65    /// This is a specialization for a specific use case where we can take a fast path
66    /// if there is just one when/then pair and both the `then` and `else` expressions
67    /// are literal values
68    /// CASE WHEN condition THEN literal ELSE literal END
69    ScalarOrScalar,
70    /// This is a specialization for a specific use case where we can take a fast path
71    /// if there is just one when/then pair and both the `then` and `else` are expressions
72    ///
73    /// CASE WHEN condition THEN expression ELSE expression END
74    ExpressionOrExpression(ProjectedCaseBody),
75}
76
77/// The body of a CASE expression which consists of an optional base expression, the "when/then"
78/// branches and an optional "else" branch.
79#[derive(Debug, Hash, PartialEq, Eq)]
80struct CaseBody {
81    /// Optional base expression that can be compared to literal values in the "when" expressions
82    expr: Option<Arc<dyn PhysicalExpr>>,
83    /// One or more when/then expressions
84    when_then_expr: Vec<WhenThen>,
85    /// Optional "else" expression
86    else_expr: Option<Arc<dyn PhysicalExpr>>,
87}
88
89impl CaseBody {
90    /// Derives a [ProjectedCaseBody] from this [CaseBody].
91    fn project(&self) -> Result<ProjectedCaseBody> {
92        // Determine the set of columns that are used in all the expressions of the case body.
93        let mut used_column_indices = HashSet::<usize>::new();
94        let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
95            expr.apply(|expr| {
96                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
97                    used_column_indices.insert(column.index());
98                }
99                Ok(TreeNodeRecursion::Continue)
100            })
101            .expect("Closure cannot fail");
102        };
103
104        if let Some(e) = &self.expr {
105            collect_column_indices(e);
106        }
107        self.when_then_expr.iter().for_each(|(w, t)| {
108            collect_column_indices(w);
109            collect_column_indices(t);
110        });
111        if let Some(e) = &self.else_expr {
112            collect_column_indices(e);
113        }
114
115        // Construct a mapping from the original column index to the projected column index.
116        let column_index_map = used_column_indices
117            .iter()
118            .enumerate()
119            .map(|(projected, original)| (*original, projected))
120            .collect::<HashMap<usize, usize>>();
121
122        // Construct the projected body by rewriting each expression from the original body
123        // using the column index mapping.
124        let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
125            Arc::clone(expr)
126                .transform_down(|e| {
127                    if let Some(column) = e.as_any().downcast_ref::<Column>() {
128                        let original = column.index();
129                        let projected = *column_index_map.get(&original).unwrap();
130                        if projected != original {
131                            return Ok(Transformed::yes(Arc::new(Column::new(
132                                column.name(),
133                                projected,
134                            ))));
135                        }
136                    }
137                    Ok(Transformed::no(e))
138                })
139                .map(|t| t.data)
140        };
141
142        let projected_body = CaseBody {
143            expr: self.expr.as_ref().map(project).transpose()?,
144            when_then_expr: self
145                .when_then_expr
146                .iter()
147                .map(|(e, t)| Ok((project(e)?, project(t)?)))
148                .collect::<Result<Vec<_>>>()?,
149            else_expr: self.else_expr.as_ref().map(project).transpose()?,
150        };
151
152        // Construct the projection vector
153        let projection = column_index_map
154            .iter()
155            .sorted_by_key(|(_, v)| **v)
156            .map(|(k, _)| *k)
157            .collect::<Vec<_>>();
158
159        Ok(ProjectedCaseBody {
160            projection,
161            body: projected_body,
162        })
163    }
164}
165
166/// A derived case body that can be used to evaluate a case expression after projecting
167/// record batches using a projection vector.
168///
169/// This is used to avoid filtering columns that are not used in the
170/// input `RecordBatch` when progressively evaluating a `CASE` expression's
171/// remainder batches. Filtering these columns is wasteful since for a record
172/// batch of `n` rows, filtering requires at worst a copy of `n - 1` values
173/// per array. If these filtered values will never be accessed, the time spent
174/// producing them is better avoided.
175///
176/// For example, if we are evaluating the following case expression that
177/// only references columns B and D:
178///
179/// ```sql
180/// SELECT CASE WHEN B > 10 THEN D ELSE NULL END FROM (VALUES (...)) T(A, B, C, D)
181/// ```
182///
183/// Of the 4 input columns `[A, B, C, D]`, the `CASE` expression only access `B` and `D`.
184/// Filtering `A` and `C` would be unnecessary and wasteful.
185///
186/// If we only retain columns `B` and `D` using `RecordBatch::project` and the projection vector
187/// `[1, 3]`, the indices of these two columns will change to `[0, 1]`. To evaluate the
188/// case expression, it will need to be rewritten from `CASE WHEN B@1 > 10 THEN D@3 ELSE NULL END`
189/// to `CASE WHEN B@0 > 10 THEN D@1 ELSE NULL END`.
190///
191/// The projection vector and the rewritten expression (which only differs from the original in
192/// column reference indices) are held in a `ProjectedCaseBody`.
193#[derive(Debug, Hash, PartialEq, Eq)]
194struct ProjectedCaseBody {
195    projection: Vec<usize>,
196    body: CaseBody,
197}
198
199/// The CASE expression is similar to a series of nested if/else and there are two forms that
200/// can be used. The first form consists of a series of boolean "when" expressions with
201/// corresponding "then" expressions, and an optional "else" expression.
202///
203/// CASE WHEN condition THEN result
204///      [WHEN ...]
205///      [ELSE result]
206/// END
207///
208/// The second form uses a base expression and then a series of "when" clauses that match on a
209/// literal value.
210///
211/// CASE expression
212///     WHEN value THEN result
213///     [WHEN ...]
214///     [ELSE result]
215/// END
216#[derive(Debug, Hash, PartialEq, Eq)]
217pub struct CaseExpr {
218    /// The case expression body
219    body: CaseBody,
220    /// Evaluation method to use
221    eval_method: EvalMethod,
222}
223
224impl std::fmt::Display for CaseExpr {
225    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
226        write!(f, "CASE ")?;
227        if let Some(e) = &self.body.expr {
228            write!(f, "{e} ")?;
229        }
230        for (w, t) in &self.body.when_then_expr {
231            write!(f, "WHEN {w} THEN {t} ")?;
232        }
233        if let Some(e) = &self.body.else_expr {
234            write!(f, "ELSE {e} ")?;
235        }
236        write!(f, "END")
237    }
238}
239
240/// This is a specialization for a specific use case where we can take a fast path
241/// for expressions that are infallible and can be cheaply computed for the entire
242/// record batch rather than just for the rows where the predicate is true. For now,
243/// this is limited to use with Column expressions but could potentially be used for other
244/// expressions in the future
245fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
246    expr.as_any().is::<Column>()
247}
248
249/// Creates a [FilterPredicate] from a boolean array.
250fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
251    let mut filter_builder = FilterBuilder::new(predicate);
252    if optimize {
253        // Always optimize the filter since we use them multiple times.
254        filter_builder = filter_builder.optimize();
255    }
256    filter_builder.build()
257}
258
259fn multiple_arrays(data_type: &DataType) -> bool {
260    match data_type {
261        DataType::Struct(fields) => {
262            fields.len() > 1
263                || fields.len() == 1 && multiple_arrays(fields[0].data_type())
264        }
265        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
266        _ => false,
267    }
268}
269
270// This should be removed when https://github.com/apache/arrow-rs/pull/8693
271// is merged and becomes available.
272fn filter_record_batch(
273    record_batch: &RecordBatch,
274    filter: &FilterPredicate,
275) -> std::result::Result<RecordBatch, ArrowError> {
276    let filtered_columns = record_batch
277        .columns()
278        .iter()
279        .map(|a| filter_array(a, filter))
280        .collect::<std::result::Result<Vec<_>, _>>()?;
281    // SAFETY: since we start from a valid RecordBatch, there's no need to revalidate the schema
282    // since the set of columns has not changed.
283    // The input column arrays all had the same length (since they're coming from a valid RecordBatch)
284    // and the filtering them with the same filter will produces a new set of arrays with identical
285    // lengths.
286    unsafe {
287        Ok(RecordBatch::new_unchecked(
288            record_batch.schema(),
289            filtered_columns,
290            filter.count(),
291        ))
292    }
293}
294
295// This function exists purely to be able to use the same call style
296// for `filter_record_batch` and `filter_array` at the point of use.
297// When https://github.com/apache/arrow-rs/pull/8693 is available, replace
298// both with method calls on `FilterPredicate`.
299#[inline(always)]
300fn filter_array(
301    array: &dyn Array,
302    filter: &FilterPredicate,
303) -> std::result::Result<ArrayRef, ArrowError> {
304    filter.filter(array)
305}
306
307fn merge(
308    mask: &BooleanArray,
309    truthy: ColumnarValue,
310    falsy: ColumnarValue,
311) -> std::result::Result<ArrayRef, ArrowError> {
312    let (truthy, truthy_is_scalar) = match truthy {
313        ColumnarValue::Array(a) => (a, false),
314        ColumnarValue::Scalar(s) => (s.to_array()?, true),
315    };
316    let (falsy, falsy_is_scalar) = match falsy {
317        ColumnarValue::Array(a) => (a, false),
318        ColumnarValue::Scalar(s) => (s.to_array()?, true),
319    };
320
321    if truthy_is_scalar && falsy_is_scalar {
322        return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy));
323    }
324
325    let falsy = falsy.to_data();
326    let truthy = truthy.to_data();
327
328    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());
329
330    // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
331    // fill with falsy values
332
333    // keep track of how much is filled
334    let mut filled = 0;
335    let mut falsy_offset = 0;
336    let mut truthy_offset = 0;
337
338    SlicesIterator::new(mask).for_each(|(start, end)| {
339        // the gap needs to be filled with falsy values
340        if start > filled {
341            if falsy_is_scalar {
342                for _ in filled..start {
343                    // Copy the first item from the 'falsy' array into the output buffer.
344                    mutable.extend(1, 0, 1);
345                }
346            } else {
347                let falsy_length = start - filled;
348                let falsy_end = falsy_offset + falsy_length;
349                mutable.extend(1, falsy_offset, falsy_end);
350                falsy_offset = falsy_end;
351            }
352        }
353        // fill with truthy values
354        if truthy_is_scalar {
355            for _ in start..end {
356                // Copy the first item from the 'truthy' array into the output buffer.
357                mutable.extend(0, 0, 1);
358            }
359        } else {
360            let truthy_length = end - start;
361            let truthy_end = truthy_offset + truthy_length;
362            mutable.extend(0, truthy_offset, truthy_end);
363            truthy_offset = truthy_end;
364        }
365        filled = end;
366    });
367    // the remaining part is falsy
368    if filled < mask.len() {
369        if falsy_is_scalar {
370            for _ in filled..mask.len() {
371                // Copy the first item from the 'falsy' array into the output buffer.
372                mutable.extend(1, 0, 1);
373            }
374        } else {
375            let falsy_length = mask.len() - filled;
376            let falsy_end = falsy_offset + falsy_length;
377            mutable.extend(1, falsy_offset, falsy_end);
378        }
379    }
380
381    let data = mutable.freeze();
382    Ok(make_array(data))
383}
384
385/// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from
386/// those values.
387///
388/// Each element in `indices` is the index of an array in `values`. The `indices` array is processed
389/// sequentially. The first occurrence of index value `n` will be mapped to the first
390/// value of the array at index `n`. The second occurrence to the second value, and so on.
391/// An index value where `PartialResultIndex::is_none` is `true` is used to indicate null values.
392///
393/// # Implementation notes
394///
395/// This algorithm is similar in nature to both `zip` and `interleave`, but there are some important
396/// differences.
397///
398/// In contrast to `zip`, this function supports multiple input arrays. Instead of a boolean
399/// selection vector, an index array is to take values from the input arrays, and a special marker
400/// value is used to indicate null values.
401///
402/// In contrast to `interleave`, this function does not use pairs of indices. The values in
403/// `indices` serve the same purpose as the first value in the pairs passed to `interleave`.
404/// The index in the array is implicit and is derived from the number of times a particular array
405/// index occurs.
406/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values
407/// in contiguous slices. In the example below, the two subsequent elements from array `2` can be
408/// copied in a single operation from the source array instead of copying them one by one.
409/// Long spans of null values are also especially cheap because they do not need to be represented
410/// in an input array.
411///
412/// # Safety
413///
414/// This function does not check that the number of occurrences of any particular array index matches
415/// the length of the corresponding input array. If an array contains more values than required, the
416/// spurious values will be ignored. If an array contains fewer values than necessary, this function
417/// will panic.
418///
419/// # Example
420///
421/// ```text
422/// ┌───────────┐  ┌─────────┐                             ┌─────────┐
423/// │┌─────────┐│  │   None  │                             │   NULL  │
424/// ││    A    ││  ├─────────┤                             ├─────────┤
425/// │└─────────┘│  │    1    │                             │    B    │
426/// │┌─────────┐│  ├─────────┤                             ├─────────┤
427/// ││    B    ││  │    0    │    merge(values, indices)   │    A    │
428/// │└─────────┘│  ├─────────┤  ─────────────────────────▶ ├─────────┤
429/// │┌─────────┐│  │   None  │                             │   NULL  │
430/// ││    C    ││  ├─────────┤                             ├─────────┤
431/// │├─────────┤│  │    2    │                             │    C    │
432/// ││    D    ││  ├─────────┤                             ├─────────┤
433/// │└─────────┘│  │    2    │                             │    D    │
434/// └───────────┘  └─────────┘                             └─────────┘
435///    values        indices                                  result
436/// ```
437fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
438    #[cfg(debug_assertions)]
439    for ix in indices {
440        if let Some(index) = ix.index() {
441            assert!(
442                index < values.len(),
443                "Index out of bounds: {} >= {}",
444                index,
445                values.len()
446            );
447        }
448    }
449
450    let data_refs = values.iter().collect();
451    let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
452
453    // This loop extends the mutable array by taking slices from the partial results.
454    //
455    // take_offsets keeps track of how many values have been taken from each array.
456    let mut take_offsets = vec![0; values.len() + 1];
457    let mut start_row_ix = 0;
458    loop {
459        let array_ix = indices[start_row_ix];
460
461        // Determine the length of the slice to take.
462        let mut end_row_ix = start_row_ix + 1;
463        while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
464            end_row_ix += 1;
465        }
466        let slice_length = end_row_ix - start_row_ix;
467
468        // Extend mutable with either nulls or with values from the array.
469        match array_ix.index() {
470            None => mutable.extend_nulls(slice_length),
471            Some(index) => {
472                let start_offset = take_offsets[index];
473                let end_offset = start_offset + slice_length;
474                mutable.extend(index, start_offset, end_offset);
475                take_offsets[index] = end_offset;
476            }
477        }
478
479        if end_row_ix == indices.len() {
480            break;
481        } else {
482            // Set the start_row_ix for the next slice.
483            start_row_ix = end_row_ix;
484        }
485    }
486
487    Ok(make_array(mutable.freeze()))
488}
489
/// A compact index into the list of partial result arrays.
///
/// The index is stored as a `u32` rather than a `usize`, and [`NONE_VALUE`] (`u32::MAX`) is
/// reserved to mean "no partial result". Using that sentinel instead of an `Option` keeps
/// vectors of these indices as small as possible.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
    /// Position in the partial result list, or [`NONE_VALUE`] for the placeholder.
    index: u32,
}

/// The reserved [`PartialResultIndex`] value that marks a row without a partial result.
const NONE_VALUE: u32 = u32::MAX;
500
501impl PartialResultIndex {
502    /// Returns the 'none' placeholder value.
503    fn none() -> Self {
504        Self { index: NONE_VALUE }
505    }
506
507    fn zero() -> Self {
508        Self { index: 0 }
509    }
510
511    /// Creates a new partial result index.
512    ///
513    /// If the provided value is greater than or equal to `u32::MAX`
514    /// an error will be returned.
515    fn try_new(index: usize) -> Result<Self> {
516        let Ok(index) = u32::try_from(index) else {
517            return internal_err!("Partial result index exceeds limit");
518        };
519
520        if index == NONE_VALUE {
521            return internal_err!("Partial result index exceeds limit");
522        }
523
524        Ok(Self { index })
525    }
526
527    /// Determines if this index is the 'none' placeholder value or not.
528    fn is_none(&self) -> bool {
529        self.index == NONE_VALUE
530    }
531
532    /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise.
533    fn index(&self) -> Option<usize> {
534        if self.is_none() {
535            None
536        } else {
537            Some(self.index as usize)
538        }
539    }
540}
541
542impl Debug for PartialResultIndex {
543    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
544        if self.is_none() {
545            write!(f, "null")
546        } else {
547            write!(f, "{}", self.index)
548        }
549    }
550}
551
552enum ResultState {
553    /// The final result is an array containing only null values.
554    Empty,
555    /// The final result needs to be computed by merging the data in `arrays`.
556    Partial {
557        // A `Vec` of partial results that should be merged.
558        // `partial_result_indices` contains indexes into this vec.
559        arrays: Vec<ArrayData>,
560        // Indicates per result row from which array in `partial_results` a value should be taken.
561        indices: Vec<PartialResultIndex>,
562    },
563    /// A single branch matched all input rows. When creating the final result, no further merging
564    /// of partial results is necessary.
565    Complete(ColumnarValue),
566}
567
568/// A builder for constructing result arrays for CASE expressions.
569///
570/// Rather than building a monolithic array containing all results, it maintains a set of
571/// partial result arrays and a mapping that indicates for each row which partial array
572/// contains the result value for that row.
573///
574/// On finish(), the builder will merge all partial results into a single array if necessary.
575/// If all rows evaluated to the same array, that array can be returned directly without
576/// any merging overhead.
577struct ResultBuilder {
578    data_type: DataType,
579    /// The number of rows in the final result.
580    row_count: usize,
581    state: ResultState,
582}
583
584impl ResultBuilder {
585    /// Creates a new ResultBuilder that will produce arrays of the given data type.
586    ///
587    /// The `row_count` parameter indicates the number of rows in the final result.
588    fn new(data_type: &DataType, row_count: usize) -> Self {
589        Self {
590            data_type: data_type.clone(),
591            row_count,
592            state: Empty,
593        }
594    }
595
596    /// Adds a result for one branch of the case expression.
597    ///
598    /// `row_indices` should be a [UInt32Array] containing [RecordBatch] relative row indices
599    /// for which `value` contains result values.
600    ///
601    /// If `value` is a scalar, the scalar value will be used as the value for each row in `row_indices`.
602    ///
603    /// If `value` is an array, the values from the array and the indices from `row_indices` will be
604    /// processed pairwise. The lengths of `value` and `row_indices` must match.
605    ///
606    /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the
607    /// record batch. The then expression produced the value array `[A, D]`.
608    /// After adding this result, the result array will have been added to `partial arrays` and
609    /// `partial indices` will have been updated at indexes `1` and `4`.
610    ///
611    /// ```text
612    ///  ┌─────────┐     ┌─────────┐┌───────────┐                            ┌─────────┐┌───────────┐
613    ///  │    C    │     │ 0: None ││┌ 0 ──────┐│                            │ 0: None ││┌ 0 ──────┐│
614    ///  ├─────────┤     ├─────────┤││    A    ││                            ├─────────┤││    A    ││
615    ///  │    D    │     │ 1: None ││└─────────┘│                            │ 1:  2   ││└─────────┘│
616    ///  └─────────┘     ├─────────┤│┌ 1 ──────┐│   add_branch_result(       ├─────────┤│┌ 1 ──────┐│
617    ///   matching       │ 2:  0   │││    B    ││     row indices,           │ 2:  0   │││    B    ││
618    /// 'then' values    ├─────────┤│└─────────┘│     value                  ├─────────┤│└─────────┘│
619    ///                  │ 3: None ││           │   )                        │ 3: None ││┌ 2 ──────┐│
620    ///  ┌─────────┐     ├─────────┤│           │ ─────────────────────────▶ ├─────────┤││    C    ││
621    ///  │    1    │     │ 4: None ││           │                            │ 4:  2   ││├─────────┤│
622    ///  ├─────────┤     ├─────────┤│           │                            ├─────────┤││    D    ││
623    ///  │    4    │     │ 5:  1   ││           │                            │ 5:  1   ││└─────────┘│
624    ///  └─────────┘     └─────────┘└───────────┘                            └─────────┘└───────────┘
625    /// row indices        partial     partial                                 partial     partial
626    ///                    indices     arrays                                  indices     arrays
627    /// ```
628    fn add_branch_result(
629        &mut self,
630        row_indices: &ArrayRef,
631        value: ColumnarValue,
632    ) -> Result<()> {
633        match value {
634            ColumnarValue::Array(a) => {
635                if a.len() != row_indices.len() {
636                    internal_err!("Array length must match row indices length")
637                } else if row_indices.len() == self.row_count {
638                    self.set_complete_result(ColumnarValue::Array(a))
639                } else {
640                    self.add_partial_result(row_indices, a.to_data())
641                }
642            }
643            ColumnarValue::Scalar(s) => {
644                if row_indices.len() == self.row_count {
645                    self.set_complete_result(ColumnarValue::Scalar(s))
646                } else {
647                    self.add_partial_result(
648                        row_indices,
649                        s.to_array_of_size(row_indices.len())?.to_data(),
650                    )
651                }
652            }
653        }
654    }
655
656    /// Adds a partial result array.
657    ///
658    /// This method adds the given array data as a partial result and updates the index mapping
659    /// to indicate that the specified rows should take their values from this array.
660    /// The partial results will be merged into a single array when finish() is called.
661    fn add_partial_result(
662        &mut self,
663        row_indices: &ArrayRef,
664        row_values: ArrayData,
665    ) -> Result<()> {
666        if row_indices.null_count() != 0 {
667            return internal_err!("Row indices must not contain nulls");
668        }
669
670        match &mut self.state {
671            Empty => {
672                let array_index = PartialResultIndex::zero();
673                let mut indices = vec![PartialResultIndex::none(); self.row_count];
674                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
675                    indices[*row_ix as usize] = array_index;
676                }
677
678                self.state = Partial {
679                    arrays: vec![row_values],
680                    indices,
681                };
682
683                Ok(())
684            }
685            Partial { arrays, indices } => {
686                let array_index = PartialResultIndex::try_new(arrays.len())?;
687
688                arrays.push(row_values);
689
690                for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
691                    // This is check is only active for debug config because the callers of this method,
692                    // `case_when_with_expr` and `case_when_no_expr`, already ensure that
693                    // they only calculate a value for each row at most once.
694                    #[cfg(debug_assertions)]
695                    if !indices[*row_ix as usize].is_none() {
696                        return internal_err!("Duplicate value for row {}", *row_ix);
697                    }
698
699                    indices[*row_ix as usize] = array_index;
700                }
701                Ok(())
702            }
703            Complete(_) => internal_err!(
704                "Cannot add a partial result when complete result is already set"
705            ),
706        }
707    }
708
709    /// Sets a result that applies to all rows.
710    ///
711    /// This is an optimization for cases where all rows evaluate to the same result.
712    /// When a complete result is set, the builder will return it directly from finish()
713    /// without any merging overhead.
714    fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
715        match &self.state {
716            Empty => {
717                self.state = Complete(value);
718                Ok(())
719            }
720            Partial { .. } => {
721                internal_err!(
722                    "Cannot set a complete result when there are already partial results"
723                )
724            }
725            Complete(_) => internal_err!("Complete result already set"),
726        }
727    }
728
729    /// Finishes building the result and returns the final array.
730    fn finish(self) -> Result<ColumnarValue> {
731        match self.state {
732            Empty => {
733                // No complete result and no partial results.
734                // This can happen for case expressions with no else branch where no rows
735                // matched.
736                Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
737                    &self.data_type,
738                )?))
739            }
740            Partial { arrays, indices } => {
741                // Merge partial results into a single array.
742                Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?))
743            }
744            Complete(v) => {
745                // If we have a complete result, we can just return it.
746                Ok(v)
747            }
748        }
749    }
750}
751
752impl CaseExpr {
753    /// Create a new CASE WHEN expression
754    pub fn try_new(
755        expr: Option<Arc<dyn PhysicalExpr>>,
756        when_then_expr: Vec<WhenThen>,
757        else_expr: Option<Arc<dyn PhysicalExpr>>,
758    ) -> Result<Self> {
759        // normalize null literals to None in the else_expr (this already happens
760        // during SQL planning, but not necessarily for other use cases)
761        let else_expr = match &else_expr {
762            Some(e) => match e.as_any().downcast_ref::<Literal>() {
763                Some(lit) if lit.value().is_null() => None,
764                _ => else_expr,
765            },
766            _ => else_expr,
767        };
768
769        if when_then_expr.is_empty() {
770            return exec_err!("There must be at least one WHEN clause");
771        }
772
773        let body = CaseBody {
774            expr,
775            when_then_expr,
776            else_expr,
777        };
778
779        let eval_method = if body.expr.is_some() {
780            EvalMethod::WithExpression(body.project()?)
781        } else if body.when_then_expr.len() == 1
782            && is_cheap_and_infallible(&(body.when_then_expr[0].1))
783            && body.else_expr.is_none()
784        {
785            EvalMethod::InfallibleExprOrNull
786        } else if body.when_then_expr.len() == 1
787            && body.when_then_expr[0].1.as_any().is::<Literal>()
788            && body.else_expr.is_some()
789            && body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
790        {
791            EvalMethod::ScalarOrScalar
792        } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
793            EvalMethod::ExpressionOrExpression(body.project()?)
794        } else {
795            EvalMethod::NoExpression(body.project()?)
796        };
797
798        Ok(Self { body, eval_method })
799    }
800
801    /// Optional base expression that can be compared to literal values in the "when" expressions
802    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
803        self.body.expr.as_ref()
804    }
805
806    /// One or more when/then expressions
807    pub fn when_then_expr(&self) -> &[WhenThen] {
808        &self.body.when_then_expr
809    }
810
811    /// Optional "else" expression
812    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
813        self.body.else_expr.as_ref()
814    }
815}
816
817impl CaseBody {
818    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
819        // since all then results have the same data type, we can choose any one as the
820        // return data type except for the null.
821        let mut data_type = DataType::Null;
822        for i in 0..self.when_then_expr.len() {
823            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
824            if !data_type.equals_datatype(&DataType::Null) {
825                break;
826            }
827        }
828        // if all then results are null, we use data type of else expr instead if possible.
829        if data_type.equals_datatype(&DataType::Null) {
830            if let Some(e) = &self.else_expr {
831                data_type = e.data_type(input_schema)?;
832            }
833        }
834
835        Ok(data_type)
836    }
837
    /// Evaluates the "simple" CASE form (`CASE base WHEN v1 THEN r1 ... [ELSE e] END`)
    /// against `batch`, producing values of `return_type`.
    ///
    /// See [CaseExpr::case_when_with_expr].
    ///
    /// Rows whose base value is NULL are resolved first, because they can never match a
    /// WHEN value. The remaining rows are tested branch by branch. Rows matched by a
    /// branch are evaluated with that branch's THEN expression on just those rows, then
    /// removed from `remainder_batch` / `remainder_rows`. So each row is evaluated by at
    /// most one THEN expression (or the ELSE expression). Results are handed to a
    /// `ResultBuilder` together with the original row indices they belong to.
    ///
    /// `self.expr` is unwrapped without a check, so this must only be called on a body
    /// that has a base expression.
    ///
    /// The ELSE expression is wrapped with `try_cast` to `return_type` before being
    /// evaluated. THEN results are passed to the builder as evaluated.
    fn case_when_with_expr(
        &self,
        batch: &RecordBatch,
        return_type: &DataType,
    ) -> Result<ColumnarValue> {
        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());

        // `remainder_rows` contains the indices of the rows that need to be evaluated
        let mut remainder_rows: ArrayRef =
            Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
        // `remainder_batch` contains the rows themselves that need to be evaluated
        let mut remainder_batch = Cow::Borrowed(batch);

        // evaluate the base expression
        let mut base_values = self
            .expr
            .as_ref()
            .unwrap()
            .evaluate(batch)?
            .into_array(batch.num_rows())?;

        // Fill in a result value already for rows where the base expression value is null
        // Since each when expression is tested against the base expression using the equality
        // operator, null base values can never match any when expression. `x = NULL` is falsy,
        // for all possible values of `x`.
        if base_values.null_count() > 0 {
            // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'.
            // We already checked there are nulls, so we can be sure a new buffer will not be
            // created.
            let base_not_nulls = is_not_null(base_values.as_ref())?;
            let base_all_null = base_values.null_count() == remainder_batch.num_rows();

            // If there is an else expression, use that as the default value for the null rows
            // Otherwise the default `null` value from the result builder will be used.
            if let Some(e) = &self.else_expr {
                let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;

                if base_all_null {
                    // All base values were null, so no need to filter
                    let nulls_value = expr.evaluate(&remainder_batch)?;
                    result_builder.add_branch_result(&remainder_rows, nulls_value)?;
                } else {
                    // Filter out the null rows and evaluate the else expression for those
                    let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
                    let nulls_batch =
                        filter_record_batch(&remainder_batch, &nulls_filter)?;
                    let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
                    let nulls_value = expr.evaluate(&nulls_batch)?;
                    result_builder.add_branch_result(&nulls_rows, nulls_value)?;
                }
            }

            // All base values are null, so we can return early
            if base_all_null {
                return result_builder.finish();
            }

            // Remove the null rows from the remainder batch
            // (the row indices and base values are filtered with the same mask so the
            // three stay aligned row-for-row).
            let not_null_filter = create_filter(&base_not_nulls, true);
            remainder_batch =
                Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
            remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
            base_values = filter_array(&base_values, &not_null_filter)?;
        }

        // The types of case and when expressions will be coerced to match.
        // We only need to check if the base_value is nested.
        let base_value_is_nested = base_values.data_type().is_nested();

        for i in 0..self.when_then_expr.len() {
            // Evaluate the 'when' predicate for the remainder batch
            // This results in a boolean array with the same length as the remaining number of rows
            let when_expr = &self.when_then_expr[i].0;
            let when_value = match when_expr.evaluate(&remainder_batch)? {
                ColumnarValue::Array(a) => {
                    compare_with_eq(&a, &base_values, base_value_is_nested)
                }
                ColumnarValue::Scalar(s) => {
                    compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
                }
            }?;

            // `true_count` ignores `true` values where the validity bit is not set, so there's
            // no need to call `prep_null_mask_filter`.
            let when_true_count = when_value.true_count();

            // If the 'when' predicate did not match any rows, continue to the next branch immediately
            if when_true_count == 0 {
                continue;
            }

            // If the 'when' predicate matched all remaining rows, there is no need to filter
            if when_true_count == remainder_batch.num_rows() {
                let then_expression = &self.when_then_expr[i].1;
                let then_value = then_expression.evaluate(&remainder_batch)?;
                result_builder.add_branch_result(&remainder_rows, then_value)?;
                return result_builder.finish();
            }

            // Filter the remainder batch based on the 'when' value
            // This results in a batch containing only the rows that need to be evaluated
            // for the current branch
            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
            // this unconditionally.
            let then_filter = create_filter(&when_value, true);
            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
            let then_rows = filter_array(&remainder_rows, &then_filter)?;

            let then_expression = &self.when_then_expr[i].1;
            let then_value = then_expression.evaluate(&then_batch)?;
            result_builder.add_branch_result(&then_rows, then_value)?;

            // If this is the last 'when' branch and there is no 'else' expression, there's no
            // point in calculating the remaining rows.
            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
                return result_builder.finish();
            }

            // Prepare the next when branch (or the else branch)
            let next_selection = match when_value.null_count() {
                0 => not(&when_value),
                _ => {
                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
                    // as false
                    not(&prep_null_mask_filter(&when_value))
                }
            }?;
            // Keep batch, row indices and base values aligned for the next branch.
            let next_filter = create_filter(&next_selection, true);
            remainder_batch =
                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
            base_values = filter_array(&base_values, &next_filter)?;
        }

        // If we reached this point, some rows were left unmatched.
        // Check if those need to be evaluated using the 'else' expression.
        if let Some(e) = &self.else_expr {
            // keep `else_expr`'s data type and return type consistent
            // NOTE(review): if the null-base path above ran, an identical `try_cast`
            // was already built there; it is rebuilt here.
            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
            let else_value = expr.evaluate(&remainder_batch)?;
            result_builder.add_branch_result(&remainder_rows, else_value)?;
        }

        result_builder.finish()
    }
984
985    /// See [CaseExpr::case_when_no_expr].
986    fn case_when_no_expr(
987        &self,
988        batch: &RecordBatch,
989        return_type: &DataType,
990    ) -> Result<ColumnarValue> {
991        let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
992
993        // `remainder_rows` contains the indices of the rows that need to be evaluated
994        let mut remainder_rows: ArrayRef =
995            Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
996        // `remainder_batch` contains the rows themselves that need to be evaluated
997        let mut remainder_batch = Cow::Borrowed(batch);
998
999        for i in 0..self.when_then_expr.len() {
1000            // Evaluate the 'when' predicate for the remainder batch
1001            // This results in a boolean array with the same length as the remaining number of rows
1002            let when_predicate = &self.when_then_expr[i].0;
1003            let when_value = when_predicate
1004                .evaluate(&remainder_batch)?
1005                .into_array(remainder_batch.num_rows())?;
1006            let when_value = as_boolean_array(&when_value).map_err(|_| {
1007                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1008            })?;
1009
1010            // `true_count` ignores `true` values where the validity bit is not set, so there's
1011            // no need to call `prep_null_mask_filter`.
1012            let when_true_count = when_value.true_count();
1013
1014            // If the 'when' predicate did not match any rows, continue to the next branch immediately
1015            if when_true_count == 0 {
1016                continue;
1017            }
1018
1019            // If the 'when' predicate matched all remaining rows, there is no need to filter
1020            if when_true_count == remainder_batch.num_rows() {
1021                let then_expression = &self.when_then_expr[i].1;
1022                let then_value = then_expression.evaluate(&remainder_batch)?;
1023                result_builder.add_branch_result(&remainder_rows, then_value)?;
1024                return result_builder.finish();
1025            }
1026
1027            // Filter the remainder batch based on the 'when' value
1028            // This results in a batch containing only the rows that need to be evaluated
1029            // for the current branch
1030            // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
1031            // this unconditionally.
1032            let then_filter = create_filter(when_value, true);
1033            let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
1034            let then_rows = filter_array(&remainder_rows, &then_filter)?;
1035
1036            let then_expression = &self.when_then_expr[i].1;
1037            let then_value = then_expression.evaluate(&then_batch)?;
1038            result_builder.add_branch_result(&then_rows, then_value)?;
1039
1040            // If this is the last 'when' branch and there is no 'else' expression, there's no
1041            // point in calculating the remaining rows.
1042            if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
1043                return result_builder.finish();
1044            }
1045
1046            // Prepare the next when branch (or the else branch)
1047            let next_selection = match when_value.null_count() {
1048                0 => not(when_value),
1049                _ => {
1050                    // `prep_null_mask_filter` is required to ensure the not operation treats nulls
1051                    // as false
1052                    not(&prep_null_mask_filter(when_value))
1053                }
1054            }?;
1055            let next_filter = create_filter(&next_selection, true);
1056            remainder_batch =
1057                Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
1058            remainder_rows = filter_array(&remainder_rows, &next_filter)?;
1059        }
1060
1061        // If we reached this point, some rows were left unmatched.
1062        // Check if those need to be evaluated using the 'else' expression.
1063        if let Some(e) = &self.else_expr {
1064            // keep `else_expr`'s data type and return type consistent
1065            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
1066            let else_value = expr.evaluate(&remainder_batch)?;
1067            result_builder.add_branch_result(&remainder_rows, else_value)?;
1068        }
1069
1070        result_builder.finish()
1071    }
1072
1073    /// See [CaseExpr::expr_or_expr].
1074    fn expr_or_expr(
1075        &self,
1076        batch: &RecordBatch,
1077        when_value: &BooleanArray,
1078    ) -> Result<ColumnarValue> {
1079        let when_value = match when_value.null_count() {
1080            0 => Cow::Borrowed(when_value),
1081            _ => {
1082                // `prep_null_mask_filter` is required to ensure null is treated as false
1083                Cow::Owned(prep_null_mask_filter(when_value))
1084            }
1085        };
1086
1087        let optimize_filter = batch.num_columns() > 1
1088            || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
1089
1090        let when_filter = create_filter(&when_value, optimize_filter);
1091        let then_batch = filter_record_batch(batch, &when_filter)?;
1092        let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
1093
1094        let else_selection = not(&when_value)?;
1095        let else_filter = create_filter(&else_selection, optimize_filter);
1096        let else_batch = filter_record_batch(batch, &else_filter)?;
1097
1098        // keep `else_expr`'s data type and return type consistent
1099        let e = self.else_expr.as_ref().unwrap();
1100        let return_type = self.data_type(&batch.schema())?;
1101        let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
1102            .unwrap_or_else(|_| Arc::clone(e));
1103
1104        let else_value = else_expr.evaluate(&else_batch)?;
1105
1106        Ok(ColumnarValue::Array(merge(
1107            &when_value,
1108            then_value,
1109            else_value,
1110        )?))
1111    }
1112}
1113
1114impl CaseExpr {
1115    /// This function evaluates the form of CASE that matches an expression to fixed values.
1116    ///
1117    /// CASE expression
1118    ///     WHEN value THEN result
1119    ///     [WHEN ...]
1120    ///     [ELSE result]
1121    /// END
1122    fn case_when_with_expr(
1123        &self,
1124        batch: &RecordBatch,
1125        projected: &ProjectedCaseBody,
1126    ) -> Result<ColumnarValue> {
1127        let return_type = self.data_type(&batch.schema())?;
1128        if projected.projection.len() < batch.num_columns() {
1129            let projected_batch = batch.project(&projected.projection)?;
1130            projected
1131                .body
1132                .case_when_with_expr(&projected_batch, &return_type)
1133        } else {
1134            self.body.case_when_with_expr(batch, &return_type)
1135        }
1136    }
1137
1138    /// This function evaluates the form of CASE where each WHEN expression is a boolean
1139    /// expression.
1140    ///
1141    /// CASE WHEN condition THEN result
1142    ///      [WHEN ...]
1143    ///      [ELSE result]
1144    /// END
1145    fn case_when_no_expr(
1146        &self,
1147        batch: &RecordBatch,
1148        projected: &ProjectedCaseBody,
1149    ) -> Result<ColumnarValue> {
1150        let return_type = self.data_type(&batch.schema())?;
1151        if projected.projection.len() < batch.num_columns() {
1152            let projected_batch = batch.project(&projected.projection)?;
1153            projected
1154                .body
1155                .case_when_no_expr(&projected_batch, &return_type)
1156        } else {
1157            self.body.case_when_no_expr(batch, &return_type)
1158        }
1159    }
1160
1161    /// This function evaluates the specialized case of:
1162    ///
1163    /// CASE WHEN condition THEN column
1164    ///      [ELSE NULL]
1165    /// END
1166    ///
1167    /// Note that this function is only safe to use for "then" expressions
1168    /// that are infallible because the expression will be evaluated for all
1169    /// rows in the input batch.
1170    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1171        let when_expr = &self.body.when_then_expr[0].0;
1172        let then_expr = &self.body.when_then_expr[0].1;
1173
1174        match when_expr.evaluate(batch)? {
1175            // WHEN true --> column
1176            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
1177                then_expr.evaluate(batch)
1178            }
1179            // WHEN [false | null] --> NULL
1180            ColumnarValue::Scalar(_) => {
1181                // return scalar NULL value
1182                ScalarValue::try_from(self.data_type(&batch.schema())?)
1183                    .map(ColumnarValue::Scalar)
1184            }
1185            // WHEN column --> column
1186            ColumnarValue::Array(bit_mask) => {
1187                let bit_mask = bit_mask
1188                    .as_any()
1189                    .downcast_ref::<BooleanArray>()
1190                    .expect("predicate should evaluate to a boolean array");
1191                // invert the bitmask
1192                let bit_mask = match bit_mask.null_count() {
1193                    0 => not(bit_mask)?,
1194                    _ => not(&prep_null_mask_filter(bit_mask))?,
1195                };
1196                match then_expr.evaluate(batch)? {
1197                    ColumnarValue::Array(array) => {
1198                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
1199                    }
1200                    ColumnarValue::Scalar(_) => {
1201                        internal_err!("expression did not evaluate to an array")
1202                    }
1203                }
1204            }
1205        }
1206    }
1207
1208    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1209        let return_type = self.data_type(&batch.schema())?;
1210
1211        // evaluate when expression
1212        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1213        let when_value = when_value.into_array(batch.num_rows())?;
1214        let when_value = as_boolean_array(&when_value).map_err(|_| {
1215            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
1216        })?;
1217
1218        // Treat 'NULL' as false value
1219        let when_value = match when_value.null_count() {
1220            0 => Cow::Borrowed(when_value),
1221            _ => Cow::Owned(prep_null_mask_filter(when_value)),
1222        };
1223
1224        // evaluate then_value
1225        let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
1226        let then_value = Scalar::new(then_value.into_array(1)?);
1227
1228        let Some(e) = &self.body.else_expr else {
1229            return internal_err!("expression did not evaluate to an array");
1230        };
1231        // keep `else_expr`'s data type and return type consistent
1232        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
1233        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
1234        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
1235    }
1236
1237    fn expr_or_expr(
1238        &self,
1239        batch: &RecordBatch,
1240        projected: &ProjectedCaseBody,
1241    ) -> Result<ColumnarValue> {
1242        // evaluate when condition on batch
1243        let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1244        // `num_rows == 1` is intentional to avoid expanding scalars.
1245        // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1246        // below will avoid incorrectly using the scalar as a merge/zip mask.
1247        let when_value = when_value.into_array(1)?;
1248        let when_value = as_boolean_array(&when_value).map_err(|e| {
1249            DataFusionError::Context(
1250                "WHEN expression did not return a BooleanArray".to_string(),
1251                Box::new(e),
1252            )
1253        })?;
1254
1255        let true_count = when_value.true_count();
1256        if true_count == when_value.len() {
1257            // All input rows are true, just call the 'then' expression
1258            self.body.when_then_expr[0].1.evaluate(batch)
1259        } else if true_count == 0 {
1260            // All input rows are false/null, just call the 'else' expression
1261            self.body.else_expr.as_ref().unwrap().evaluate(batch)
1262        } else if projected.projection.len() < batch.num_columns() {
1263            // The case expressions do not use all the columns of the input batch.
1264            // Project first to reduce time spent filtering.
1265            let projected_batch = batch.project(&projected.projection)?;
1266            projected.body.expr_or_expr(&projected_batch, when_value)
1267        } else {
1268            // All columns are used in the case expressions, so there is no need to project.
1269            self.body.expr_or_expr(batch, when_value)
1270        }
1271    }
1272}
1273
1274impl PhysicalExpr for CaseExpr {
1275    /// Return a reference to Any that can be used for down-casting
1276    fn as_any(&self) -> &dyn Any {
1277        self
1278    }
1279
1280    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1281        self.body.data_type(input_schema)
1282    }
1283
1284    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
1285        // this expression is nullable if any of the input expressions are nullable
1286        let then_nullable = self
1287            .body
1288            .when_then_expr
1289            .iter()
1290            .map(|(_, t)| t.nullable(input_schema))
1291            .collect::<Result<Vec<_>>>()?;
1292        if then_nullable.contains(&true) {
1293            Ok(true)
1294        } else if let Some(e) = &self.body.else_expr {
1295            e.nullable(input_schema)
1296        } else {
1297            // CASE produces NULL if there is no `else` expr
1298            // (aka when none of the `when_then_exprs` match)
1299            Ok(true)
1300        }
1301    }
1302
1303    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1304        match &self.eval_method {
1305            EvalMethod::WithExpression(p) => {
1306                // this use case evaluates "expr" and then compares the values with the "when"
1307                // values
1308                self.case_when_with_expr(batch, p)
1309            }
1310            EvalMethod::NoExpression(p) => {
1311                // The "when" conditions all evaluate to boolean in this use case and can be
1312                // arbitrary expressions
1313                self.case_when_no_expr(batch, p)
1314            }
1315            EvalMethod::InfallibleExprOrNull => {
1316                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
1317                self.case_column_or_null(batch)
1318            }
1319            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
1320            EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p),
1321        }
1322    }
1323
1324    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1325        let mut children = vec![];
1326        if let Some(expr) = &self.body.expr {
1327            children.push(expr)
1328        }
1329        self.body.when_then_expr.iter().for_each(|(cond, value)| {
1330            children.push(cond);
1331            children.push(value);
1332        });
1333
1334        if let Some(else_expr) = &self.body.else_expr {
1335            children.push(else_expr)
1336        }
1337        children
1338    }
1339
1340    // For physical CaseExpr, we do not allow modifying children size
1341    fn with_new_children(
1342        self: Arc<Self>,
1343        children: Vec<Arc<dyn PhysicalExpr>>,
1344    ) -> Result<Arc<dyn PhysicalExpr>> {
1345        if children.len() != self.children().len() {
1346            internal_err!("CaseExpr: Wrong number of children")
1347        } else {
1348            let (expr, when_then_expr, else_expr) =
1349                match (self.expr().is_some(), self.body.else_expr.is_some()) {
1350                    (true, true) => (
1351                        Some(&children[0]),
1352                        &children[1..children.len() - 1],
1353                        Some(&children[children.len() - 1]),
1354                    ),
1355                    (true, false) => {
1356                        (Some(&children[0]), &children[1..children.len()], None)
1357                    }
1358                    (false, true) => (
1359                        None,
1360                        &children[0..children.len() - 1],
1361                        Some(&children[children.len() - 1]),
1362                    ),
1363                    (false, false) => (None, &children[0..children.len()], None),
1364                };
1365            Ok(Arc::new(CaseExpr::try_new(
1366                expr.cloned(),
1367                when_then_expr.iter().cloned().tuples().collect(),
1368                else_expr.cloned(),
1369            )?))
1370        }
1371    }
1372
1373    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1374        write!(f, "CASE ")?;
1375        if let Some(e) = &self.body.expr {
1376            e.fmt_sql(f)?;
1377            write!(f, " ")?;
1378        }
1379
1380        for (w, t) in &self.body.when_then_expr {
1381            write!(f, "WHEN ")?;
1382            w.fmt_sql(f)?;
1383            write!(f, " THEN ")?;
1384            t.fmt_sql(f)?;
1385            write!(f, " ")?;
1386        }
1387
1388        if let Some(e) = &self.body.else_expr {
1389            write!(f, "ELSE ")?;
1390            e.fmt_sql(f)?;
1391            write!(f, " ")?;
1392        }
1393        write!(f, "END")
1394    }
1395}
1396
1397/// Create a CASE expression
1398pub fn case(
1399    expr: Option<Arc<dyn PhysicalExpr>>,
1400    when_thens: Vec<WhenThen>,
1401    else_expr: Option<Arc<dyn PhysicalExpr>>,
1402) -> Result<Arc<dyn PhysicalExpr>> {
1403    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
1404}
1405
1406#[cfg(test)]
1407mod tests {
1408    use super::*;
1409
1410    use crate::expressions::{binary, cast, col, lit, BinaryExpr};
1411    use arrow::buffer::Buffer;
1412    use arrow::datatypes::DataType::Float64;
1413    use arrow::datatypes::Field;
1414    use datafusion_common::cast::{as_float64_array, as_int32_array};
1415    use datafusion_common::plan_err;
1416    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
1417    use datafusion_expr::type_coercion::binary::comparison_coercion;
1418    use datafusion_expr::Operator;
1419    use datafusion_physical_expr_common::physical_expr::fmt_sql;
1420
1421    #[test]
1422    fn case_with_expr() -> Result<()> {
1423        let batch = case_test_batch()?;
1424        let schema = batch.schema();
1425
1426        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
1427        let when1 = lit("foo");
1428        let then1 = lit(123i32);
1429        let when2 = lit("bar");
1430        let then2 = lit(456i32);
1431
1432        let expr = generate_case_when_with_type_coercion(
1433            Some(col("a", &schema)?),
1434            vec![(when1, then1), (when2, then2)],
1435            None,
1436            schema.as_ref(),
1437        )?;
1438        let result = expr
1439            .evaluate(&batch)?
1440            .into_array(batch.num_rows())
1441            .expect("Failed to convert to array");
1442        let result = as_int32_array(&result)?;
1443
1444        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1445
1446        assert_eq!(expected, result);
1447
1448        Ok(())
1449    }
1450
1451    #[test]
1452    fn case_with_expr_else() -> Result<()> {
1453        let batch = case_test_batch()?;
1454        let schema = batch.schema();
1455
1456        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
1457        let when1 = lit("foo");
1458        let then1 = lit(123i32);
1459        let when2 = lit("bar");
1460        let then2 = lit(456i32);
1461        let else_value = lit(999i32);
1462
1463        let expr = generate_case_when_with_type_coercion(
1464            Some(col("a", &schema)?),
1465            vec![(when1, then1), (when2, then2)],
1466            Some(else_value),
1467            schema.as_ref(),
1468        )?;
1469        let result = expr
1470            .evaluate(&batch)?
1471            .into_array(batch.num_rows())
1472            .expect("Failed to convert to array");
1473        let result = as_int32_array(&result)?;
1474
1475        let expected =
1476            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1477
1478        assert_eq!(expected, result);
1479
1480        Ok(())
1481    }
1482
1483    #[test]
1484    fn case_with_expr_divide_by_zero() -> Result<()> {
1485        let batch = case_test_batch1()?;
1486        let schema = batch.schema();
1487
1488        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
1489        let when1 = lit(0i32);
1490        let then1 = lit(ScalarValue::Float64(None));
1491        let else_value = binary(
1492            lit(25.0f64),
1493            Operator::Divide,
1494            cast(col("a", &schema)?, &batch.schema(), Float64)?,
1495            &batch.schema(),
1496        )?;
1497
1498        let expr = generate_case_when_with_type_coercion(
1499            Some(col("a", &schema)?),
1500            vec![(when1, then1)],
1501            Some(else_value),
1502            schema.as_ref(),
1503        )?;
1504        let result = expr
1505            .evaluate(&batch)?
1506            .into_array(batch.num_rows())
1507            .expect("Failed to convert to array");
1508        let result =
1509            as_float64_array(&result).expect("failed to downcast to Float64Array");
1510
1511        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1512
1513        assert_eq!(expected, result);
1514
1515        Ok(())
1516    }
1517
1518    #[test]
1519    fn case_without_expr() -> Result<()> {
1520        let batch = case_test_batch()?;
1521        let schema = batch.schema();
1522
1523        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
1524        let when1 = binary(
1525            col("a", &schema)?,
1526            Operator::Eq,
1527            lit("foo"),
1528            &batch.schema(),
1529        )?;
1530        let then1 = lit(123i32);
1531        let when2 = binary(
1532            col("a", &schema)?,
1533            Operator::Eq,
1534            lit("bar"),
1535            &batch.schema(),
1536        )?;
1537        let then2 = lit(456i32);
1538
1539        let expr = generate_case_when_with_type_coercion(
1540            None,
1541            vec![(when1, then1), (when2, then2)],
1542            None,
1543            schema.as_ref(),
1544        )?;
1545        let result = expr
1546            .evaluate(&batch)?
1547            .into_array(batch.num_rows())
1548            .expect("Failed to convert to array");
1549        let result = as_int32_array(&result)?;
1550
1551        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
1552
1553        assert_eq!(expected, result);
1554
1555        Ok(())
1556    }
1557
1558    #[test]
1559    fn case_with_expr_when_null() -> Result<()> {
1560        let batch = case_test_batch()?;
1561        let schema = batch.schema();
1562
1563        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
1564        let when1 = lit(ScalarValue::Utf8(None));
1565        let then1 = lit(0i32);
1566        let when2 = col("a", &schema)?;
1567        let then2 = lit(123i32);
1568        let else_value = lit(999i32);
1569
1570        let expr = generate_case_when_with_type_coercion(
1571            Some(col("a", &schema)?),
1572            vec![(when1, then1), (when2, then2)],
1573            Some(else_value),
1574            schema.as_ref(),
1575        )?;
1576        let result = expr
1577            .evaluate(&batch)?
1578            .into_array(batch.num_rows())
1579            .expect("Failed to convert to array");
1580        let result = as_int32_array(&result)?;
1581
1582        let expected =
1583            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
1584
1585        assert_eq!(expected, result);
1586
1587        Ok(())
1588    }
1589
1590    #[test]
1591    fn case_without_expr_divide_by_zero() -> Result<()> {
1592        let batch = case_test_batch1()?;
1593        let schema = batch.schema();
1594
1595        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
1596        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
1597        let then1 = binary(
1598            lit(25.0f64),
1599            Operator::Divide,
1600            cast(col("a", &schema)?, &batch.schema(), Float64)?,
1601            &batch.schema(),
1602        )?;
1603        let x = lit(ScalarValue::Float64(None));
1604
1605        let expr = generate_case_when_with_type_coercion(
1606            None,
1607            vec![(when1, then1)],
1608            Some(x),
1609            schema.as_ref(),
1610        )?;
1611        let result = expr
1612            .evaluate(&batch)?
1613            .into_array(batch.num_rows())
1614            .expect("Failed to convert to array");
1615        let result =
1616            as_float64_array(&result).expect("failed to downcast to Float64Array");
1617
1618        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
1619
1620        assert_eq!(expected, result);
1621
1622        Ok(())
1623    }
1624
1625    fn case_test_batch1() -> Result<RecordBatch> {
1626        let schema = Schema::new(vec![
1627            Field::new("a", DataType::Int32, true),
1628            Field::new("b", DataType::Int32, true),
1629            Field::new("c", DataType::Int32, true),
1630        ]);
1631        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
1632        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
1633        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
1634        let batch = RecordBatch::try_new(
1635            Arc::new(schema),
1636            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
1637        )?;
1638        Ok(batch)
1639    }
1640
1641    #[test]
1642    fn case_without_expr_else() -> Result<()> {
1643        let batch = case_test_batch()?;
1644        let schema = batch.schema();
1645
1646        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
1647        let when1 = binary(
1648            col("a", &schema)?,
1649            Operator::Eq,
1650            lit("foo"),
1651            &batch.schema(),
1652        )?;
1653        let then1 = lit(123i32);
1654        let when2 = binary(
1655            col("a", &schema)?,
1656            Operator::Eq,
1657            lit("bar"),
1658            &batch.schema(),
1659        )?;
1660        let then2 = lit(456i32);
1661        let else_value = lit(999i32);
1662
1663        let expr = generate_case_when_with_type_coercion(
1664            None,
1665            vec![(when1, then1), (when2, then2)],
1666            Some(else_value),
1667            schema.as_ref(),
1668        )?;
1669        let result = expr
1670            .evaluate(&batch)?
1671            .into_array(batch.num_rows())
1672            .expect("Failed to convert to array");
1673        let result = as_int32_array(&result)?;
1674
1675        let expected =
1676            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
1677
1678        assert_eq!(expected, result);
1679
1680        Ok(())
1681    }
1682
1683    #[test]
1684    fn case_with_type_cast() -> Result<()> {
1685        let batch = case_test_batch()?;
1686        let schema = batch.schema();
1687
1688        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
1689        let when = binary(
1690            col("a", &schema)?,
1691            Operator::Eq,
1692            lit("foo"),
1693            &batch.schema(),
1694        )?;
1695        let then = lit(123.3f64);
1696        let else_value = lit(999i32);
1697
1698        let expr = generate_case_when_with_type_coercion(
1699            None,
1700            vec![(when, then)],
1701            Some(else_value),
1702            schema.as_ref(),
1703        )?;
1704        let result = expr
1705            .evaluate(&batch)?
1706            .into_array(batch.num_rows())
1707            .expect("Failed to convert to array");
1708        let result =
1709            as_float64_array(&result).expect("failed to downcast to Float64Array");
1710
1711        let expected =
1712            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
1713
1714        assert_eq!(expected, result);
1715
1716        Ok(())
1717    }
1718
1719    #[test]
1720    fn case_with_matches_and_nulls() -> Result<()> {
1721        let batch = case_test_batch_nulls()?;
1722        let schema = batch.schema();
1723
1724        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
1725        let when = binary(
1726            col("load4", &schema)?,
1727            Operator::Eq,
1728            lit(1.77f64),
1729            &batch.schema(),
1730        )?;
1731        let then = col("load4", &schema)?;
1732
1733        let expr = generate_case_when_with_type_coercion(
1734            None,
1735            vec![(when, then)],
1736            None,
1737            schema.as_ref(),
1738        )?;
1739        let result = expr
1740            .evaluate(&batch)?
1741            .into_array(batch.num_rows())
1742            .expect("Failed to convert to array");
1743        let result =
1744            as_float64_array(&result).expect("failed to downcast to Float64Array");
1745
1746        let expected =
1747            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1748
1749        assert_eq!(expected, result);
1750
1751        Ok(())
1752    }
1753
1754    #[test]
1755    fn case_with_scalar_predicate() -> Result<()> {
1756        let batch = case_test_batch_nulls()?;
1757        let schema = batch.schema();
1758
1759        // SELECT CASE WHEN TRUE THEN load4 END
1760        let when = lit(true);
1761        let then = col("load4", &schema)?;
1762        let expr = generate_case_when_with_type_coercion(
1763            None,
1764            vec![(when, then)],
1765            None,
1766            schema.as_ref(),
1767        )?;
1768
1769        // many rows
1770        let result = expr
1771            .evaluate(&batch)?
1772            .into_array(batch.num_rows())
1773            .expect("Failed to convert to array");
1774        let result =
1775            as_float64_array(&result).expect("failed to downcast to Float64Array");
1776        let expected = &Float64Array::from(vec![
1777            Some(1.77),
1778            None,
1779            None,
1780            Some(1.78),
1781            None,
1782            Some(1.77),
1783        ]);
1784        assert_eq!(expected, result);
1785
1786        // one row
1787        let expected = Float64Array::from(vec![Some(1.1)]);
1788        let batch =
1789            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
1790        let result = expr
1791            .evaluate(&batch)?
1792            .into_array(batch.num_rows())
1793            .expect("Failed to convert to array");
1794        let result =
1795            as_float64_array(&result).expect("failed to downcast to Float64Array");
1796        assert_eq!(&expected, result);
1797
1798        Ok(())
1799    }
1800
1801    #[test]
1802    fn case_expr_matches_and_nulls() -> Result<()> {
1803        let batch = case_test_batch_nulls()?;
1804        let schema = batch.schema();
1805
1806        // SELECT CASE load4 WHEN 1.77 THEN load4 END
1807        let expr = col("load4", &schema)?;
1808        let when = lit(1.77f64);
1809        let then = col("load4", &schema)?;
1810
1811        let expr = generate_case_when_with_type_coercion(
1812            Some(expr),
1813            vec![(when, then)],
1814            None,
1815            schema.as_ref(),
1816        )?;
1817        let result = expr
1818            .evaluate(&batch)?
1819            .into_array(batch.num_rows())
1820            .expect("Failed to convert to array");
1821        let result =
1822            as_float64_array(&result).expect("failed to downcast to Float64Array");
1823
1824        let expected =
1825            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1826
1827        assert_eq!(expected, result);
1828
1829        Ok(())
1830    }
1831
1832    #[test]
1833    fn test_when_null_and_some_cond_else_null() -> Result<()> {
1834        let batch = case_test_batch()?;
1835        let schema = batch.schema();
1836
1837        let when = binary(
1838            Arc::new(Literal::new(ScalarValue::Boolean(None))),
1839            Operator::And,
1840            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
1841            &schema,
1842        )?;
1843        let then = col("a", &schema)?;
1844
1845        // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END
1846        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
1847        let result = expr
1848            .evaluate(&batch)?
1849            .into_array(batch.num_rows())
1850            .expect("Failed to convert to array");
1851        let result = as_string_array(&result);
1852
1853        // all result values should be null
1854        assert_eq!(result.logical_null_count(), batch.num_rows());
1855        Ok(())
1856    }
1857
1858    fn case_test_batch() -> Result<RecordBatch> {
1859        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1860        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1861        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1862        Ok(batch)
1863    }
1864
1865    // Construct an array that has several NULL values whose
1866    // underlying buffer actually matches the where expr predicate
1867    fn case_test_batch_nulls() -> Result<RecordBatch> {
1868        let load4: Float64Array = vec![
1869            Some(1.77), // 1.77
1870            Some(1.77), // null <-- same value, but will be set to null
1871            Some(1.77), // null <-- same value, but will be set to null
1872            Some(1.78), // 1.78
1873            None,       // null
1874            Some(1.77), // 1.77
1875        ]
1876        .into_iter()
1877        .collect();
1878
1879        let null_buffer = Buffer::from([0b00101001u8]);
1880        let load4 = load4
1881            .into_data()
1882            .into_builder()
1883            .null_bit_buffer(Some(null_buffer))
1884            .build()
1885            .unwrap();
1886        let load4: Float64Array = load4.into();
1887
1888        let batch =
1889            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1890        Ok(batch)
1891    }
1892
1893    #[test]
1894    fn case_test_incompatible() -> Result<()> {
1895        // 1 then is int64
1896        // 2 then is boolean
1897        let batch = case_test_batch()?;
1898        let schema = batch.schema();
1899
1900        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
1901        let when1 = binary(
1902            col("a", &schema)?,
1903            Operator::Eq,
1904            lit("foo"),
1905            &batch.schema(),
1906        )?;
1907        let then1 = lit(123i32);
1908        let when2 = binary(
1909            col("a", &schema)?,
1910            Operator::Eq,
1911            lit("bar"),
1912            &batch.schema(),
1913        )?;
1914        let then2 = lit(true);
1915
1916        let expr = generate_case_when_with_type_coercion(
1917            None,
1918            vec![(when1, then1), (when2, then2)],
1919            None,
1920            schema.as_ref(),
1921        );
1922        assert!(expr.is_err());
1923
1924        // then 1 is int32
1925        // then 2 is int64
1926        // else is float
1927        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
1928        let when1 = binary(
1929            col("a", &schema)?,
1930            Operator::Eq,
1931            lit("foo"),
1932            &batch.schema(),
1933        )?;
1934        let then1 = lit(123i32);
1935        let when2 = binary(
1936            col("a", &schema)?,
1937            Operator::Eq,
1938            lit("bar"),
1939            &batch.schema(),
1940        )?;
1941        let then2 = lit(456i64);
1942        let else_expr = lit(1.23f64);
1943
1944        let expr = generate_case_when_with_type_coercion(
1945            None,
1946            vec![(when1, then1), (when2, then2)],
1947            Some(else_expr),
1948            schema.as_ref(),
1949        );
1950        assert!(expr.is_ok());
1951        let result_type = expr.unwrap().data_type(schema.as_ref())?;
1952        assert_eq!(Float64, result_type);
1953        Ok(())
1954    }
1955
1956    #[test]
1957    fn case_eq() -> Result<()> {
1958        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1959
1960        let when1 = lit("foo");
1961        let then1 = lit(123i32);
1962        let when2 = lit("bar");
1963        let then2 = lit(456i32);
1964        let else_value = lit(999i32);
1965
1966        let expr1 = generate_case_when_with_type_coercion(
1967            Some(col("a", &schema)?),
1968            vec![
1969                (Arc::clone(&when1), Arc::clone(&then1)),
1970                (Arc::clone(&when2), Arc::clone(&then2)),
1971            ],
1972            Some(Arc::clone(&else_value)),
1973            &schema,
1974        )?;
1975
1976        let expr2 = generate_case_when_with_type_coercion(
1977            Some(col("a", &schema)?),
1978            vec![
1979                (Arc::clone(&when1), Arc::clone(&then1)),
1980                (Arc::clone(&when2), Arc::clone(&then2)),
1981            ],
1982            Some(Arc::clone(&else_value)),
1983            &schema,
1984        )?;
1985
1986        let expr3 = generate_case_when_with_type_coercion(
1987            Some(col("a", &schema)?),
1988            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
1989            None,
1990            &schema,
1991        )?;
1992
1993        let expr4 = generate_case_when_with_type_coercion(
1994            Some(col("a", &schema)?),
1995            vec![(when1, then1)],
1996            Some(else_value),
1997            &schema,
1998        )?;
1999
2000        assert!(expr1.eq(&expr2));
2001        assert!(expr2.eq(&expr1));
2002
2003        assert!(expr2.ne(&expr3));
2004        assert!(expr3.ne(&expr2));
2005
2006        assert!(expr1.ne(&expr4));
2007        assert!(expr4.ne(&expr1));
2008
2009        Ok(())
2010    }
2011
2012    #[test]
2013    fn case_transform() -> Result<()> {
2014        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
2015
2016        let when1 = lit("foo");
2017        let then1 = lit(123i32);
2018        let when2 = lit("bar");
2019        let then2 = lit(456i32);
2020        let else_value = lit(999i32);
2021
2022        let expr = generate_case_when_with_type_coercion(
2023            Some(col("a", &schema)?),
2024            vec![
2025                (Arc::clone(&when1), Arc::clone(&then1)),
2026                (Arc::clone(&when2), Arc::clone(&then2)),
2027            ],
2028            Some(Arc::clone(&else_value)),
2029            &schema,
2030        )?;
2031
2032        let expr2 = Arc::clone(&expr)
2033            .transform(|e| {
2034                let transformed = match e.as_any().downcast_ref::<Literal>() {
2035                    Some(lit_value) => match lit_value.value() {
2036                        ScalarValue::Utf8(Some(str_value)) => {
2037                            Some(lit(str_value.to_uppercase()))
2038                        }
2039                        _ => None,
2040                    },
2041                    _ => None,
2042                };
2043                Ok(if let Some(transformed) = transformed {
2044                    Transformed::yes(transformed)
2045                } else {
2046                    Transformed::no(e)
2047                })
2048            })
2049            .data()
2050            .unwrap();
2051
2052        let expr3 = Arc::clone(&expr)
2053            .transform_down(|e| {
2054                let transformed = match e.as_any().downcast_ref::<Literal>() {
2055                    Some(lit_value) => match lit_value.value() {
2056                        ScalarValue::Utf8(Some(str_value)) => {
2057                            Some(lit(str_value.to_uppercase()))
2058                        }
2059                        _ => None,
2060                    },
2061                    _ => None,
2062                };
2063                Ok(if let Some(transformed) = transformed {
2064                    Transformed::yes(transformed)
2065                } else {
2066                    Transformed::no(e)
2067                })
2068            })
2069            .data()
2070            .unwrap();
2071
2072        assert!(expr.ne(&expr2));
2073        assert!(expr2.eq(&expr3));
2074
2075        Ok(())
2076    }
2077
2078    #[test]
2079    fn test_column_or_null_specialization() -> Result<()> {
2080        // create input data
2081        let mut c1 = Int32Builder::new();
2082        let mut c2 = StringBuilder::new();
2083        for i in 0..1000 {
2084            c1.append_value(i);
2085            if i % 7 == 0 {
2086                c2.append_null();
2087            } else {
2088                c2.append_value(format!("string {i}"));
2089            }
2090        }
2091        let c1 = Arc::new(c1.finish());
2092        let c2 = Arc::new(c2.finish());
2093        let schema = Schema::new(vec![
2094            Field::new("c1", DataType::Int32, true),
2095            Field::new("c2", DataType::Utf8, true),
2096        ]);
2097        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
2098
2099        // CaseWhenExprOrNull should produce same results as CaseExpr
2100        let predicate = Arc::new(BinaryExpr::new(
2101            make_col("c1", 0),
2102            Operator::LtEq,
2103            make_lit_i32(250),
2104        ));
2105        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
2106        assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
2107        match expr.evaluate(&batch)? {
2108            ColumnarValue::Array(array) => {
2109                assert_eq!(1000, array.len());
2110                assert_eq!(785, array.null_count());
2111            }
2112            _ => unreachable!(),
2113        }
2114        Ok(())
2115    }
2116
2117    #[test]
2118    fn test_expr_or_expr_specialization() -> Result<()> {
2119        let batch = case_test_batch1()?;
2120        let schema = batch.schema();
2121        let when = binary(
2122            col("a", &schema)?,
2123            Operator::LtEq,
2124            lit(2i32),
2125            &batch.schema(),
2126        )?;
2127        let then = col("b", &schema)?;
2128        let else_expr = col("c", &schema)?;
2129        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
2130        assert!(matches!(
2131            expr.eval_method,
2132            EvalMethod::ExpressionOrExpression(_)
2133        ));
2134        let result = expr
2135            .evaluate(&batch)?
2136            .into_array(batch.num_rows())
2137            .expect("Failed to convert to array");
2138        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
2139
2140        let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
2141
2142        assert_eq!(expected, result);
2143        Ok(())
2144    }
2145
2146    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
2147        Arc::new(Column::new(name, index))
2148    }
2149
2150    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
2151        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
2152    }
2153
2154    fn generate_case_when_with_type_coercion(
2155        expr: Option<Arc<dyn PhysicalExpr>>,
2156        when_thens: Vec<WhenThen>,
2157        else_expr: Option<Arc<dyn PhysicalExpr>>,
2158        input_schema: &Schema,
2159    ) -> Result<Arc<dyn PhysicalExpr>> {
2160        let coerce_type =
2161            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
2162        let (when_thens, else_expr) = match coerce_type {
2163            None => plan_err!(
2164                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
2165            ),
2166            Some(data_type) => {
2167                // cast then expr
2168                let left = when_thens
2169                    .into_iter()
2170                    .map(|(when, then)| {
2171                        let then = try_cast(then, input_schema, data_type.clone())?;
2172                        Ok((when, then))
2173                    })
2174                    .collect::<Result<Vec<_>>>()?;
2175                let right = match else_expr {
2176                    None => None,
2177                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
2178                };
2179
2180                Ok((left, right))
2181            }
2182        }?;
2183        case(expr, when_thens, else_expr)
2184    }
2185
2186    fn get_case_common_type(
2187        when_thens: &[WhenThen],
2188        else_expr: Option<Arc<dyn PhysicalExpr>>,
2189        input_schema: &Schema,
2190    ) -> Option<DataType> {
2191        let thens_type = when_thens
2192            .iter()
2193            .map(|when_then| {
2194                let data_type = &when_then.1.data_type(input_schema).unwrap();
2195                data_type.clone()
2196            })
2197            .collect::<Vec<_>>();
2198        let else_type = match else_expr {
2199            None => {
2200                // case when then exprs must have one then value
2201                thens_type[0].clone()
2202            }
2203            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
2204        };
2205        thens_type
2206            .iter()
2207            .try_fold(else_type, |left_type, right_type| {
2208                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
2209                // refactor again.
2210                comparison_coercion(&left_type, right_type)
2211            })
2212    }
2213
2214    #[test]
2215    fn test_fmt_sql() -> Result<()> {
2216        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2217
2218        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
2219        let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
2220        let then = lit(123.3f64);
2221        let else_value = lit(999i32);
2222
2223        let expr = generate_case_when_with_type_coercion(
2224            None,
2225            vec![(when, then)],
2226            Some(else_value),
2227            &schema,
2228        )?;
2229
2230        let display_string = expr.to_string();
2231        assert_eq!(
2232            display_string,
2233            "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2234        );
2235
2236        let sql_string = fmt_sql(expr.as_ref()).to_string();
2237        assert_eq!(
2238            sql_string,
2239            "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
2240        );
2241
2242        Ok(())
2243    }
2244
2245    #[test]
2246    fn test_merge_n() {
2247        let a1 = StringArray::from(vec![Some("A")]).to_data();
2248        let a2 = StringArray::from(vec![Some("B")]).to_data();
2249        let a3 = StringArray::from(vec![Some("C"), Some("D")]).to_data();
2250
2251        let indices = vec![
2252            PartialResultIndex::none(),
2253            PartialResultIndex::try_new(1).unwrap(),
2254            PartialResultIndex::try_new(0).unwrap(),
2255            PartialResultIndex::none(),
2256            PartialResultIndex::try_new(2).unwrap(),
2257            PartialResultIndex::try_new(2).unwrap(),
2258        ];
2259
2260        let merged = merge_n(&[a1, a2, a3], &indices).unwrap();
2261        let merged = merged.as_string::<i32>();
2262
2263        assert_eq!(merged.len(), indices.len());
2264        assert!(!merged.is_valid(0));
2265        assert!(merged.is_valid(1));
2266        assert_eq!(merged.value(1), "B");
2267        assert!(merged.is_valid(2));
2268        assert_eq!(merged.value(2), "A");
2269        assert!(!merged.is_valid(3));
2270        assert!(merged.is_valid(4));
2271        assert_eq!(merged.value(4), "C");
2272        assert!(merged.is_valid(5));
2273        assert_eq!(merged.value(5), "D");
2274    }
2275
2276    #[test]
2277    fn test_merge() {
2278        let a1 = Arc::new(StringArray::from(vec![Some("A"), Some("C")]));
2279        let a2 = Arc::new(StringArray::from(vec![Some("B")]));
2280
2281        let mask = BooleanArray::from(vec![true, false, true]);
2282
2283        let merged =
2284            merge(&mask, ColumnarValue::Array(a1), ColumnarValue::Array(a2)).unwrap();
2285        let merged = merged.as_string::<i32>();
2286
2287        assert_eq!(merged.len(), mask.len());
2288        assert!(merged.is_valid(0));
2289        assert_eq!(merged.value(0), "A");
2290        assert!(merged.is_valid(1));
2291        assert_eq!(merged.value(1), "B");
2292        assert!(merged.is_valid(2));
2293        assert_eq!(merged.value(2), "C");
2294    }
2295}