datafusion_physical_expr/expressions/
case.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::expressions::try_cast;
19use crate::PhysicalExpr;
20use std::borrow::Cow;
21use std::hash::Hash;
22use std::{any::Any, sync::Arc};
23
24use arrow::array::*;
25use arrow::compute::kernels::zip::zip;
26use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
27use arrow::datatypes::{DataType, Schema};
28use datafusion_common::cast::as_boolean_array;
29use datafusion_common::{
30    exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
31};
32use datafusion_expr::ColumnarValue;
33
34use super::{Column, Literal};
35use datafusion_physical_expr_common::datum::compare_with_eq;
36use itertools::Itertools;
37
38type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
39
/// Evaluation strategy for a [`CaseExpr`], chosen once when the expression is
/// constructed (`CaseExpr::try_new`) from the shape of its WHEN/THEN/ELSE parts.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression; each WHEN is a boolean condition.
    ///
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    NoExpression,
    /// Form with a base expression that is compared for equality with each WHEN
    /// value.
    ///
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    WithExpression,
    /// Fast path for a single WHEN/THEN pair without an ELSE (an `ELSE NULL`
    /// literal is normalized away by `try_new`) whose THEN expression is cheap and
    /// infallible (currently only column references). The THEN expression is
    /// evaluated for the entire record batch rather than only for the rows where
    /// the predicate is true.
    ///
    /// CASE WHEN condition THEN column [ELSE NULL] END
    InfallibleExprOrNull,
    /// Fast path for a single WHEN/THEN pair where both the THEN and the ELSE
    /// expressions are literal values.
    ///
    /// CASE WHEN condition THEN literal ELSE literal END
    ScalarOrScalar,
    /// Fast path for a single WHEN/THEN pair where both the THEN and the ELSE
    /// expressions are cheap and infallible (currently only column references),
    /// so neither can fail regardless of which rows it is evaluated on.
    ///
    /// CASE WHEN condition THEN column ELSE column END
    ExpressionOrExpression,
}
70
71/// The CASE expression is similar to a series of nested if/else and there are two forms that
72/// can be used. The first form consists of a series of boolean "when" expressions with
73/// corresponding "then" expressions, and an optional "else" expression.
74///
75/// CASE WHEN condition THEN result
76///      [WHEN ...]
77///      [ELSE result]
78/// END
79///
80/// The second form uses a base expression and then a series of "when" clauses that match on a
81/// literal value.
82///
83/// CASE expression
84///     WHEN value THEN result
85///     [WHEN ...]
86///     [ELSE result]
87/// END
88#[derive(Debug, Hash, PartialEq, Eq)]
89pub struct CaseExpr {
90    /// Optional base expression that can be compared to literal values in the "when" expressions
91    expr: Option<Arc<dyn PhysicalExpr>>,
92    /// One or more when/then expressions
93    when_then_expr: Vec<WhenThen>,
94    /// Optional "else" expression
95    else_expr: Option<Arc<dyn PhysicalExpr>>,
96    /// Evaluation method to use
97    eval_method: EvalMethod,
98}
99
100impl std::fmt::Display for CaseExpr {
101    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102        write!(f, "CASE ")?;
103        if let Some(e) = &self.expr {
104            write!(f, "{e} ")?;
105        }
106        for (w, t) in &self.when_then_expr {
107            write!(f, "WHEN {w} THEN {t} ")?;
108        }
109        if let Some(e) = &self.else_expr {
110            write!(f, "ELSE {e} ")?;
111        }
112        write!(f, "END")
113    }
114}
115
116/// This is a specialization for a specific use case where we can take a fast path
117/// for expressions that are infallible and can be cheaply computed for the entire
118/// record batch rather than just for the rows where the predicate is true. For now,
119/// this is limited to use with Column expressions but could potentially be used for other
120/// expressions in the future
121fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
122    expr.as_any().is::<Column>()
123}
124
125impl CaseExpr {
126    /// Create a new CASE WHEN expression
127    pub fn try_new(
128        expr: Option<Arc<dyn PhysicalExpr>>,
129        when_then_expr: Vec<WhenThen>,
130        else_expr: Option<Arc<dyn PhysicalExpr>>,
131    ) -> Result<Self> {
132        // normalize null literals to None in the else_expr (this already happens
133        // during SQL planning, but not necessarily for other use cases)
134        let else_expr = match &else_expr {
135            Some(e) => match e.as_any().downcast_ref::<Literal>() {
136                Some(lit) if lit.value().is_null() => None,
137                _ => else_expr,
138            },
139            _ => else_expr,
140        };
141
142        if when_then_expr.is_empty() {
143            exec_err!("There must be at least one WHEN clause")
144        } else {
145            let eval_method = if expr.is_some() {
146                EvalMethod::WithExpression
147            } else if when_then_expr.len() == 1
148                && is_cheap_and_infallible(&(when_then_expr[0].1))
149                && else_expr.is_none()
150            {
151                EvalMethod::InfallibleExprOrNull
152            } else if when_then_expr.len() == 1
153                && when_then_expr[0].1.as_any().is::<Literal>()
154                && else_expr.is_some()
155                && else_expr.as_ref().unwrap().as_any().is::<Literal>()
156            {
157                EvalMethod::ScalarOrScalar
158            } else if when_then_expr.len() == 1
159                && is_cheap_and_infallible(&(when_then_expr[0].1))
160                && else_expr.as_ref().is_some_and(is_cheap_and_infallible)
161            {
162                EvalMethod::ExpressionOrExpression
163            } else {
164                EvalMethod::NoExpression
165            };
166
167            Ok(Self {
168                expr,
169                when_then_expr,
170                else_expr,
171                eval_method,
172            })
173        }
174    }
175
176    /// Optional base expression that can be compared to literal values in the "when" expressions
177    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
178        self.expr.as_ref()
179    }
180
181    /// One or more when/then expressions
182    pub fn when_then_expr(&self) -> &[WhenThen] {
183        &self.when_then_expr
184    }
185
186    /// Optional "else" expression
187    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
188        self.else_expr.as_ref()
189    }
190}
191
192impl CaseExpr {
193    /// This function evaluates the form of CASE that matches an expression to fixed values.
194    ///
195    /// CASE expression
196    ///     WHEN value THEN result
197    ///     [WHEN ...]
198    ///     [ELSE result]
199    /// END
200    fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
201        let return_type = self.data_type(&batch.schema())?;
202        let expr = self.expr.as_ref().unwrap();
203        let base_value = expr.evaluate(batch)?;
204        let base_value = base_value.into_array(batch.num_rows())?;
205        let base_nulls = is_null(base_value.as_ref())?;
206
207        // start with nulls as default output
208        let mut current_value = new_null_array(&return_type, batch.num_rows());
209        // We only consider non-null values while comparing with whens
210        let mut remainder = not(&base_nulls)?;
211        for i in 0..self.when_then_expr.len() {
212            let when_value = self.when_then_expr[i]
213                .0
214                .evaluate_selection(batch, &remainder)?;
215            let when_value = when_value.into_array(batch.num_rows())?;
216            // build boolean array representing which rows match the "when" value
217            let when_match = compare_with_eq(
218                &when_value,
219                &base_value,
220                // The types of case and when expressions will be coerced to match.
221                // We only need to check if the base_value is nested.
222                base_value.data_type().is_nested(),
223            )?;
224            // Treat nulls as false
225            let when_match = match when_match.null_count() {
226                0 => Cow::Borrowed(&when_match),
227                _ => Cow::Owned(prep_null_mask_filter(&when_match)),
228            };
229            // Make sure we only consider rows that have not been matched yet
230            let when_match = and(&when_match, &remainder)?;
231
232            // When no rows available for when clause, skip then clause
233            if when_match.true_count() == 0 {
234                continue;
235            }
236
237            let then_value = self.when_then_expr[i]
238                .1
239                .evaluate_selection(batch, &when_match)?;
240
241            current_value = match then_value {
242                ColumnarValue::Scalar(ScalarValue::Null) => {
243                    nullif(current_value.as_ref(), &when_match)?
244                }
245                ColumnarValue::Scalar(then_value) => {
246                    zip(&when_match, &then_value.to_scalar()?, &current_value)?
247                }
248                ColumnarValue::Array(then_value) => {
249                    zip(&when_match, &then_value, &current_value)?
250                }
251            };
252
253            remainder = and_not(&remainder, &when_match)?;
254        }
255
256        if let Some(e) = self.else_expr() {
257            // keep `else_expr`'s data type and return type consistent
258            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
259            // null and unmatched tuples should be assigned else value
260            remainder = or(&base_nulls, &remainder)?;
261            let else_ = expr
262                .evaluate_selection(batch, &remainder)?
263                .into_array(batch.num_rows())?;
264            current_value = zip(&remainder, &else_, &current_value)?;
265        }
266
267        Ok(ColumnarValue::Array(current_value))
268    }
269
270    /// This function evaluates the form of CASE where each WHEN expression is a boolean
271    /// expression.
272    ///
273    /// CASE WHEN condition THEN result
274    ///      [WHEN ...]
275    ///      [ELSE result]
276    /// END
277    fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
278        let return_type = self.data_type(&batch.schema())?;
279
280        // start with nulls as default output
281        let mut current_value = new_null_array(&return_type, batch.num_rows());
282        let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
283        for i in 0..self.when_then_expr.len() {
284            let when_value = self.when_then_expr[i]
285                .0
286                .evaluate_selection(batch, &remainder)?;
287            let when_value = when_value.into_array(batch.num_rows())?;
288            let when_value = as_boolean_array(&when_value).map_err(|_| {
289                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
290            })?;
291            // Treat 'NULL' as false value
292            let when_value = match when_value.null_count() {
293                0 => Cow::Borrowed(when_value),
294                _ => Cow::Owned(prep_null_mask_filter(when_value)),
295            };
296            // Make sure we only consider rows that have not been matched yet
297            let when_value = and(&when_value, &remainder)?;
298
299            // When no rows available for when clause, skip then clause
300            if when_value.true_count() == 0 {
301                continue;
302            }
303
304            let then_value = self.when_then_expr[i]
305                .1
306                .evaluate_selection(batch, &when_value)?;
307
308            current_value = match then_value {
309                ColumnarValue::Scalar(ScalarValue::Null) => {
310                    nullif(current_value.as_ref(), &when_value)?
311                }
312                ColumnarValue::Scalar(then_value) => {
313                    zip(&when_value, &then_value.to_scalar()?, &current_value)?
314                }
315                ColumnarValue::Array(then_value) => {
316                    zip(&when_value, &then_value, &current_value)?
317                }
318            };
319
320            // Succeed tuples should be filtered out for short-circuit evaluation,
321            // null values for the current when expr should be kept
322            remainder = and_not(&remainder, &when_value)?;
323        }
324
325        if let Some(e) = self.else_expr() {
326            // keep `else_expr`'s data type and return type consistent
327            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
328            let else_ = expr
329                .evaluate_selection(batch, &remainder)?
330                .into_array(batch.num_rows())?;
331            current_value = zip(&remainder, &else_, &current_value)?;
332        }
333
334        Ok(ColumnarValue::Array(current_value))
335    }
336
337    /// This function evaluates the specialized case of:
338    ///
339    /// CASE WHEN condition THEN column
340    ///      [ELSE NULL]
341    /// END
342    ///
343    /// Note that this function is only safe to use for "then" expressions
344    /// that are infallible because the expression will be evaluated for all
345    /// rows in the input batch.
346    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
347        let when_expr = &self.when_then_expr[0].0;
348        let then_expr = &self.when_then_expr[0].1;
349
350        match when_expr.evaluate(batch)? {
351            // WHEN true --> column
352            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
353                then_expr.evaluate(batch)
354            }
355            // WHEN [false | null] --> NULL
356            ColumnarValue::Scalar(_) => {
357                // return scalar NULL value
358                ScalarValue::try_from(self.data_type(&batch.schema())?)
359                    .map(ColumnarValue::Scalar)
360            }
361            // WHEN column --> column
362            ColumnarValue::Array(bit_mask) => {
363                let bit_mask = bit_mask
364                    .as_any()
365                    .downcast_ref::<BooleanArray>()
366                    .expect("predicate should evaluate to a boolean array");
367                // invert the bitmask
368                let bit_mask = match bit_mask.null_count() {
369                    0 => not(bit_mask)?,
370                    _ => not(&prep_null_mask_filter(bit_mask))?,
371                };
372                match then_expr.evaluate(batch)? {
373                    ColumnarValue::Array(array) => {
374                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
375                    }
376                    ColumnarValue::Scalar(_) => {
377                        internal_err!("expression did not evaluate to an array")
378                    }
379                }
380            }
381        }
382    }
383
384    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
385        let return_type = self.data_type(&batch.schema())?;
386
387        // evaluate when expression
388        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
389        let when_value = when_value.into_array(batch.num_rows())?;
390        let when_value = as_boolean_array(&when_value).map_err(|_| {
391            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
392        })?;
393
394        // Treat 'NULL' as false value
395        let when_value = match when_value.null_count() {
396            0 => Cow::Borrowed(when_value),
397            _ => Cow::Owned(prep_null_mask_filter(when_value)),
398        };
399
400        // evaluate then_value
401        let then_value = self.when_then_expr[0].1.evaluate(batch)?;
402        let then_value = Scalar::new(then_value.into_array(1)?);
403
404        let Some(e) = self.else_expr() else {
405            return internal_err!("expression did not evaluate to an array");
406        };
407        // keep `else_expr`'s data type and return type consistent
408        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
409        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
410        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
411    }
412
413    fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
414        let return_type = self.data_type(&batch.schema())?;
415
416        // evalute when condition on batch
417        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
418        let when_value = when_value.into_array(batch.num_rows())?;
419        let when_value = as_boolean_array(&when_value).map_err(|e| {
420            DataFusionError::Context(
421                "WHEN expression did not return a BooleanArray".to_string(),
422                Box::new(e),
423            )
424        })?;
425
426        // Treat 'NULL' as false value
427        let when_value = match when_value.null_count() {
428            0 => Cow::Borrowed(when_value),
429            _ => Cow::Owned(prep_null_mask_filter(when_value)),
430        };
431
432        let then_value = self.when_then_expr[0]
433            .1
434            .evaluate_selection(batch, &when_value)?
435            .into_array(batch.num_rows())?;
436
437        // evaluate else expression on the values not covered by when_value
438        let remainder = not(&when_value)?;
439        let e = self.else_expr.as_ref().unwrap();
440        // keep `else_expr`'s data type and return type consistent
441        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
442            .unwrap_or_else(|_| Arc::clone(e));
443        let else_ = expr
444            .evaluate_selection(batch, &remainder)?
445            .into_array(batch.num_rows())?;
446
447        Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
448    }
449}
450
451impl PhysicalExpr for CaseExpr {
452    /// Return a reference to Any that can be used for down-casting
453    fn as_any(&self) -> &dyn Any {
454        self
455    }
456
457    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
458        // since all then results have the same data type, we can choose any one as the
459        // return data type except for the null.
460        let mut data_type = DataType::Null;
461        for i in 0..self.when_then_expr.len() {
462            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
463            if !data_type.equals_datatype(&DataType::Null) {
464                break;
465            }
466        }
467        // if all then results are null, we use data type of else expr instead if possible.
468        if data_type.equals_datatype(&DataType::Null) {
469            if let Some(e) = &self.else_expr {
470                data_type = e.data_type(input_schema)?;
471            }
472        }
473
474        Ok(data_type)
475    }
476
477    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
478        // this expression is nullable if any of the input expressions are nullable
479        let then_nullable = self
480            .when_then_expr
481            .iter()
482            .map(|(_, t)| t.nullable(input_schema))
483            .collect::<Result<Vec<_>>>()?;
484        if then_nullable.contains(&true) {
485            Ok(true)
486        } else if let Some(e) = &self.else_expr {
487            e.nullable(input_schema)
488        } else {
489            // CASE produces NULL if there is no `else` expr
490            // (aka when none of the `when_then_exprs` match)
491            Ok(true)
492        }
493    }
494
495    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
496        match self.eval_method {
497            EvalMethod::WithExpression => {
498                // this use case evaluates "expr" and then compares the values with the "when"
499                // values
500                self.case_when_with_expr(batch)
501            }
502            EvalMethod::NoExpression => {
503                // The "when" conditions all evaluate to boolean in this use case and can be
504                // arbitrary expressions
505                self.case_when_no_expr(batch)
506            }
507            EvalMethod::InfallibleExprOrNull => {
508                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
509                self.case_column_or_null(batch)
510            }
511            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
512            EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
513        }
514    }
515
516    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
517        let mut children = vec![];
518        if let Some(expr) = &self.expr {
519            children.push(expr)
520        }
521        self.when_then_expr.iter().for_each(|(cond, value)| {
522            children.push(cond);
523            children.push(value);
524        });
525
526        if let Some(else_expr) = &self.else_expr {
527            children.push(else_expr)
528        }
529        children
530    }
531
532    // For physical CaseExpr, we do not allow modifying children size
533    fn with_new_children(
534        self: Arc<Self>,
535        children: Vec<Arc<dyn PhysicalExpr>>,
536    ) -> Result<Arc<dyn PhysicalExpr>> {
537        if children.len() != self.children().len() {
538            internal_err!("CaseExpr: Wrong number of children")
539        } else {
540            let (expr, when_then_expr, else_expr) =
541                match (self.expr().is_some(), self.else_expr().is_some()) {
542                    (true, true) => (
543                        Some(&children[0]),
544                        &children[1..children.len() - 1],
545                        Some(&children[children.len() - 1]),
546                    ),
547                    (true, false) => {
548                        (Some(&children[0]), &children[1..children.len()], None)
549                    }
550                    (false, true) => (
551                        None,
552                        &children[0..children.len() - 1],
553                        Some(&children[children.len() - 1]),
554                    ),
555                    (false, false) => (None, &children[0..children.len()], None),
556                };
557            Ok(Arc::new(CaseExpr::try_new(
558                expr.cloned(),
559                when_then_expr.iter().cloned().tuples().collect(),
560                else_expr.cloned(),
561            )?))
562        }
563    }
564
565    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566        write!(f, "CASE ")?;
567        if let Some(e) = &self.expr {
568            e.fmt_sql(f)?;
569            write!(f, " ")?;
570        }
571
572        for (w, t) in &self.when_then_expr {
573            write!(f, "WHEN ")?;
574            w.fmt_sql(f)?;
575            write!(f, " THEN ")?;
576            t.fmt_sql(f)?;
577            write!(f, " ")?;
578        }
579
580        if let Some(e) = &self.else_expr {
581            write!(f, "ELSE ")?;
582            e.fmt_sql(f)?;
583            write!(f, " ")?;
584        }
585        write!(f, "END")
586    }
587}
588
589/// Create a CASE expression
590pub fn case(
591    expr: Option<Arc<dyn PhysicalExpr>>,
592    when_thens: Vec<WhenThen>,
593    else_expr: Option<Arc<dyn PhysicalExpr>>,
594) -> Result<Arc<dyn PhysicalExpr>> {
595    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
596}
597
598#[cfg(test)]
599mod tests {
600    use super::*;
601
602    use crate::expressions::{binary, cast, col, lit, BinaryExpr};
603    use arrow::buffer::Buffer;
604    use arrow::datatypes::DataType::Float64;
605    use arrow::datatypes::Field;
606    use datafusion_common::cast::{as_float64_array, as_int32_array};
607    use datafusion_common::plan_err;
608    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
609    use datafusion_expr::type_coercion::binary::comparison_coercion;
610    use datafusion_expr::Operator;
611    use datafusion_physical_expr_common::physical_expr::fmt_sql;
612
613    #[test]
614    fn case_with_expr() -> Result<()> {
615        let batch = case_test_batch()?;
616        let schema = batch.schema();
617
618        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
619        let when1 = lit("foo");
620        let then1 = lit(123i32);
621        let when2 = lit("bar");
622        let then2 = lit(456i32);
623
624        let expr = generate_case_when_with_type_coercion(
625            Some(col("a", &schema)?),
626            vec![(when1, then1), (when2, then2)],
627            None,
628            schema.as_ref(),
629        )?;
630        let result = expr
631            .evaluate(&batch)?
632            .into_array(batch.num_rows())
633            .expect("Failed to convert to array");
634        let result = as_int32_array(&result)?;
635
636        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
637
638        assert_eq!(expected, result);
639
640        Ok(())
641    }
642
643    #[test]
644    fn case_with_expr_else() -> Result<()> {
645        let batch = case_test_batch()?;
646        let schema = batch.schema();
647
648        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
649        let when1 = lit("foo");
650        let then1 = lit(123i32);
651        let when2 = lit("bar");
652        let then2 = lit(456i32);
653        let else_value = lit(999i32);
654
655        let expr = generate_case_when_with_type_coercion(
656            Some(col("a", &schema)?),
657            vec![(when1, then1), (when2, then2)],
658            Some(else_value),
659            schema.as_ref(),
660        )?;
661        let result = expr
662            .evaluate(&batch)?
663            .into_array(batch.num_rows())
664            .expect("Failed to convert to array");
665        let result = as_int32_array(&result)?;
666
667        let expected =
668            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
669
670        assert_eq!(expected, result);
671
672        Ok(())
673    }
674
675    #[test]
676    fn case_with_expr_divide_by_zero() -> Result<()> {
677        let batch = case_test_batch1()?;
678        let schema = batch.schema();
679
680        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
681        let when1 = lit(0i32);
682        let then1 = lit(ScalarValue::Float64(None));
683        let else_value = binary(
684            lit(25.0f64),
685            Operator::Divide,
686            cast(col("a", &schema)?, &batch.schema(), Float64)?,
687            &batch.schema(),
688        )?;
689
690        let expr = generate_case_when_with_type_coercion(
691            Some(col("a", &schema)?),
692            vec![(when1, then1)],
693            Some(else_value),
694            schema.as_ref(),
695        )?;
696        let result = expr
697            .evaluate(&batch)?
698            .into_array(batch.num_rows())
699            .expect("Failed to convert to array");
700        let result =
701            as_float64_array(&result).expect("failed to downcast to Float64Array");
702
703        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
704
705        assert_eq!(expected, result);
706
707        Ok(())
708    }
709
710    #[test]
711    fn case_without_expr() -> Result<()> {
712        let batch = case_test_batch()?;
713        let schema = batch.schema();
714
715        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
716        let when1 = binary(
717            col("a", &schema)?,
718            Operator::Eq,
719            lit("foo"),
720            &batch.schema(),
721        )?;
722        let then1 = lit(123i32);
723        let when2 = binary(
724            col("a", &schema)?,
725            Operator::Eq,
726            lit("bar"),
727            &batch.schema(),
728        )?;
729        let then2 = lit(456i32);
730
731        let expr = generate_case_when_with_type_coercion(
732            None,
733            vec![(when1, then1), (when2, then2)],
734            None,
735            schema.as_ref(),
736        )?;
737        let result = expr
738            .evaluate(&batch)?
739            .into_array(batch.num_rows())
740            .expect("Failed to convert to array");
741        let result = as_int32_array(&result)?;
742
743        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
744
745        assert_eq!(expected, result);
746
747        Ok(())
748    }
749
750    #[test]
751    fn case_with_expr_when_null() -> Result<()> {
752        let batch = case_test_batch()?;
753        let schema = batch.schema();
754
755        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
756        let when1 = lit(ScalarValue::Utf8(None));
757        let then1 = lit(0i32);
758        let when2 = col("a", &schema)?;
759        let then2 = lit(123i32);
760        let else_value = lit(999i32);
761
762        let expr = generate_case_when_with_type_coercion(
763            Some(col("a", &schema)?),
764            vec![(when1, then1), (when2, then2)],
765            Some(else_value),
766            schema.as_ref(),
767        )?;
768        let result = expr
769            .evaluate(&batch)?
770            .into_array(batch.num_rows())
771            .expect("Failed to convert to array");
772        let result = as_int32_array(&result)?;
773
774        let expected =
775            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
776
777        assert_eq!(expected, result);
778
779        Ok(())
780    }
781
782    #[test]
783    fn case_without_expr_divide_by_zero() -> Result<()> {
784        let batch = case_test_batch1()?;
785        let schema = batch.schema();
786
787        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
788        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
789        let then1 = binary(
790            lit(25.0f64),
791            Operator::Divide,
792            cast(col("a", &schema)?, &batch.schema(), Float64)?,
793            &batch.schema(),
794        )?;
795        let x = lit(ScalarValue::Float64(None));
796
797        let expr = generate_case_when_with_type_coercion(
798            None,
799            vec![(when1, then1)],
800            Some(x),
801            schema.as_ref(),
802        )?;
803        let result = expr
804            .evaluate(&batch)?
805            .into_array(batch.num_rows())
806            .expect("Failed to convert to array");
807        let result =
808            as_float64_array(&result).expect("failed to downcast to Float64Array");
809
810        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
811
812        assert_eq!(expected, result);
813
814        Ok(())
815    }
816
817    fn case_test_batch1() -> Result<RecordBatch> {
818        let schema = Schema::new(vec![
819            Field::new("a", DataType::Int32, true),
820            Field::new("b", DataType::Int32, true),
821            Field::new("c", DataType::Int32, true),
822        ]);
823        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
824        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
825        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
826        let batch = RecordBatch::try_new(
827            Arc::new(schema),
828            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
829        )?;
830        Ok(batch)
831    }
832
833    #[test]
834    fn case_without_expr_else() -> Result<()> {
835        let batch = case_test_batch()?;
836        let schema = batch.schema();
837
838        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
839        let when1 = binary(
840            col("a", &schema)?,
841            Operator::Eq,
842            lit("foo"),
843            &batch.schema(),
844        )?;
845        let then1 = lit(123i32);
846        let when2 = binary(
847            col("a", &schema)?,
848            Operator::Eq,
849            lit("bar"),
850            &batch.schema(),
851        )?;
852        let then2 = lit(456i32);
853        let else_value = lit(999i32);
854
855        let expr = generate_case_when_with_type_coercion(
856            None,
857            vec![(when1, then1), (when2, then2)],
858            Some(else_value),
859            schema.as_ref(),
860        )?;
861        let result = expr
862            .evaluate(&batch)?
863            .into_array(batch.num_rows())
864            .expect("Failed to convert to array");
865        let result = as_int32_array(&result)?;
866
867        let expected =
868            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
869
870        assert_eq!(expected, result);
871
872        Ok(())
873    }
874
875    #[test]
876    fn case_with_type_cast() -> Result<()> {
877        let batch = case_test_batch()?;
878        let schema = batch.schema();
879
880        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
881        let when = binary(
882            col("a", &schema)?,
883            Operator::Eq,
884            lit("foo"),
885            &batch.schema(),
886        )?;
887        let then = lit(123.3f64);
888        let else_value = lit(999i32);
889
890        let expr = generate_case_when_with_type_coercion(
891            None,
892            vec![(when, then)],
893            Some(else_value),
894            schema.as_ref(),
895        )?;
896        let result = expr
897            .evaluate(&batch)?
898            .into_array(batch.num_rows())
899            .expect("Failed to convert to array");
900        let result =
901            as_float64_array(&result).expect("failed to downcast to Float64Array");
902
903        let expected =
904            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
905
906        assert_eq!(expected, result);
907
908        Ok(())
909    }
910
911    #[test]
912    fn case_with_matches_and_nulls() -> Result<()> {
913        let batch = case_test_batch_nulls()?;
914        let schema = batch.schema();
915
916        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
917        let when = binary(
918            col("load4", &schema)?,
919            Operator::Eq,
920            lit(1.77f64),
921            &batch.schema(),
922        )?;
923        let then = col("load4", &schema)?;
924
925        let expr = generate_case_when_with_type_coercion(
926            None,
927            vec![(when, then)],
928            None,
929            schema.as_ref(),
930        )?;
931        let result = expr
932            .evaluate(&batch)?
933            .into_array(batch.num_rows())
934            .expect("Failed to convert to array");
935        let result =
936            as_float64_array(&result).expect("failed to downcast to Float64Array");
937
938        let expected =
939            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
940
941        assert_eq!(expected, result);
942
943        Ok(())
944    }
945
946    #[test]
947    fn case_with_scalar_predicate() -> Result<()> {
948        let batch = case_test_batch_nulls()?;
949        let schema = batch.schema();
950
951        // SELECT CASE WHEN TRUE THEN load4 END
952        let when = lit(true);
953        let then = col("load4", &schema)?;
954        let expr = generate_case_when_with_type_coercion(
955            None,
956            vec![(when, then)],
957            None,
958            schema.as_ref(),
959        )?;
960
961        // many rows
962        let result = expr
963            .evaluate(&batch)?
964            .into_array(batch.num_rows())
965            .expect("Failed to convert to array");
966        let result =
967            as_float64_array(&result).expect("failed to downcast to Float64Array");
968        let expected = &Float64Array::from(vec![
969            Some(1.77),
970            None,
971            None,
972            Some(1.78),
973            None,
974            Some(1.77),
975        ]);
976        assert_eq!(expected, result);
977
978        // one row
979        let expected = Float64Array::from(vec![Some(1.1)]);
980        let batch =
981            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
982        let result = expr
983            .evaluate(&batch)?
984            .into_array(batch.num_rows())
985            .expect("Failed to convert to array");
986        let result =
987            as_float64_array(&result).expect("failed to downcast to Float64Array");
988        assert_eq!(&expected, result);
989
990        Ok(())
991    }
992
993    #[test]
994    fn case_expr_matches_and_nulls() -> Result<()> {
995        let batch = case_test_batch_nulls()?;
996        let schema = batch.schema();
997
998        // SELECT CASE load4 WHEN 1.77 THEN load4 END
999        let expr = col("load4", &schema)?;
1000        let when = lit(1.77f64);
1001        let then = col("load4", &schema)?;
1002
1003        let expr = generate_case_when_with_type_coercion(
1004            Some(expr),
1005            vec![(when, then)],
1006            None,
1007            schema.as_ref(),
1008        )?;
1009        let result = expr
1010            .evaluate(&batch)?
1011            .into_array(batch.num_rows())
1012            .expect("Failed to convert to array");
1013        let result =
1014            as_float64_array(&result).expect("failed to downcast to Float64Array");
1015
1016        let expected =
1017            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1018
1019        assert_eq!(expected, result);
1020
1021        Ok(())
1022    }
1023
1024    #[test]
1025    fn test_when_null_and_some_cond_else_null() -> Result<()> {
1026        let batch = case_test_batch()?;
1027        let schema = batch.schema();
1028
1029        let when = binary(
1030            Arc::new(Literal::new(ScalarValue::Boolean(None))),
1031            Operator::And,
1032            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
1033            &schema,
1034        )?;
1035        let then = col("a", &schema)?;
1036
1037        // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END
1038        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
1039        let result = expr
1040            .evaluate(&batch)?
1041            .into_array(batch.num_rows())
1042            .expect("Failed to convert to array");
1043        let result = as_string_array(&result);
1044
1045        // all result values should be null
1046        assert_eq!(result.logical_null_count(), batch.num_rows());
1047        Ok(())
1048    }
1049
1050    fn case_test_batch() -> Result<RecordBatch> {
1051        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1052        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1053        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1054        Ok(batch)
1055    }
1056
1057    // Construct an array that has several NULL values whose
1058    // underlying buffer actually matches the where expr predicate
1059    fn case_test_batch_nulls() -> Result<RecordBatch> {
1060        let load4: Float64Array = vec![
1061            Some(1.77), // 1.77
1062            Some(1.77), // null <-- same value, but will be set to null
1063            Some(1.77), // null <-- same value, but will be set to null
1064            Some(1.78), // 1.78
1065            None,       // null
1066            Some(1.77), // 1.77
1067        ]
1068        .into_iter()
1069        .collect();
1070
1071        //let valid_array = vec![true, false, false, true, false, tru
1072        let null_buffer = Buffer::from([0b00101001u8]);
1073        let load4 = load4
1074            .into_data()
1075            .into_builder()
1076            .null_bit_buffer(Some(null_buffer))
1077            .build()
1078            .unwrap();
1079        let load4: Float64Array = load4.into();
1080
1081        let batch =
1082            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1083        Ok(batch)
1084    }
1085
1086    #[test]
1087    fn case_test_incompatible() -> Result<()> {
1088        // 1 then is int64
1089        // 2 then is boolean
1090        let batch = case_test_batch()?;
1091        let schema = batch.schema();
1092
1093        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
1094        let when1 = binary(
1095            col("a", &schema)?,
1096            Operator::Eq,
1097            lit("foo"),
1098            &batch.schema(),
1099        )?;
1100        let then1 = lit(123i32);
1101        let when2 = binary(
1102            col("a", &schema)?,
1103            Operator::Eq,
1104            lit("bar"),
1105            &batch.schema(),
1106        )?;
1107        let then2 = lit(true);
1108
1109        let expr = generate_case_when_with_type_coercion(
1110            None,
1111            vec![(when1, then1), (when2, then2)],
1112            None,
1113            schema.as_ref(),
1114        );
1115        assert!(expr.is_err());
1116
1117        // then 1 is int32
1118        // then 2 is int64
1119        // else is float
1120        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
1121        let when1 = binary(
1122            col("a", &schema)?,
1123            Operator::Eq,
1124            lit("foo"),
1125            &batch.schema(),
1126        )?;
1127        let then1 = lit(123i32);
1128        let when2 = binary(
1129            col("a", &schema)?,
1130            Operator::Eq,
1131            lit("bar"),
1132            &batch.schema(),
1133        )?;
1134        let then2 = lit(456i64);
1135        let else_expr = lit(1.23f64);
1136
1137        let expr = generate_case_when_with_type_coercion(
1138            None,
1139            vec![(when1, then1), (when2, then2)],
1140            Some(else_expr),
1141            schema.as_ref(),
1142        );
1143        assert!(expr.is_ok());
1144        let result_type = expr.unwrap().data_type(schema.as_ref())?;
1145        assert_eq!(Float64, result_type);
1146        Ok(())
1147    }
1148
1149    #[test]
1150    fn case_eq() -> Result<()> {
1151        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1152
1153        let when1 = lit("foo");
1154        let then1 = lit(123i32);
1155        let when2 = lit("bar");
1156        let then2 = lit(456i32);
1157        let else_value = lit(999i32);
1158
1159        let expr1 = generate_case_when_with_type_coercion(
1160            Some(col("a", &schema)?),
1161            vec![
1162                (Arc::clone(&when1), Arc::clone(&then1)),
1163                (Arc::clone(&when2), Arc::clone(&then2)),
1164            ],
1165            Some(Arc::clone(&else_value)),
1166            &schema,
1167        )?;
1168
1169        let expr2 = generate_case_when_with_type_coercion(
1170            Some(col("a", &schema)?),
1171            vec![
1172                (Arc::clone(&when1), Arc::clone(&then1)),
1173                (Arc::clone(&when2), Arc::clone(&then2)),
1174            ],
1175            Some(Arc::clone(&else_value)),
1176            &schema,
1177        )?;
1178
1179        let expr3 = generate_case_when_with_type_coercion(
1180            Some(col("a", &schema)?),
1181            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
1182            None,
1183            &schema,
1184        )?;
1185
1186        let expr4 = generate_case_when_with_type_coercion(
1187            Some(col("a", &schema)?),
1188            vec![(when1, then1)],
1189            Some(else_value),
1190            &schema,
1191        )?;
1192
1193        assert!(expr1.eq(&expr2));
1194        assert!(expr2.eq(&expr1));
1195
1196        assert!(expr2.ne(&expr3));
1197        assert!(expr3.ne(&expr2));
1198
1199        assert!(expr1.ne(&expr4));
1200        assert!(expr4.ne(&expr1));
1201
1202        Ok(())
1203    }
1204
1205    #[test]
1206    fn case_transform() -> Result<()> {
1207        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1208
1209        let when1 = lit("foo");
1210        let then1 = lit(123i32);
1211        let when2 = lit("bar");
1212        let then2 = lit(456i32);
1213        let else_value = lit(999i32);
1214
1215        let expr = generate_case_when_with_type_coercion(
1216            Some(col("a", &schema)?),
1217            vec![
1218                (Arc::clone(&when1), Arc::clone(&then1)),
1219                (Arc::clone(&when2), Arc::clone(&then2)),
1220            ],
1221            Some(Arc::clone(&else_value)),
1222            &schema,
1223        )?;
1224
1225        let expr2 = Arc::clone(&expr)
1226            .transform(|e| {
1227                let transformed = match e.as_any().downcast_ref::<Literal>() {
1228                    Some(lit_value) => match lit_value.value() {
1229                        ScalarValue::Utf8(Some(str_value)) => {
1230                            Some(lit(str_value.to_uppercase()))
1231                        }
1232                        _ => None,
1233                    },
1234                    _ => None,
1235                };
1236                Ok(if let Some(transformed) = transformed {
1237                    Transformed::yes(transformed)
1238                } else {
1239                    Transformed::no(e)
1240                })
1241            })
1242            .data()
1243            .unwrap();
1244
1245        let expr3 = Arc::clone(&expr)
1246            .transform_down(|e| {
1247                let transformed = match e.as_any().downcast_ref::<Literal>() {
1248                    Some(lit_value) => match lit_value.value() {
1249                        ScalarValue::Utf8(Some(str_value)) => {
1250                            Some(lit(str_value.to_uppercase()))
1251                        }
1252                        _ => None,
1253                    },
1254                    _ => None,
1255                };
1256                Ok(if let Some(transformed) = transformed {
1257                    Transformed::yes(transformed)
1258                } else {
1259                    Transformed::no(e)
1260                })
1261            })
1262            .data()
1263            .unwrap();
1264
1265        assert!(expr.ne(&expr2));
1266        assert!(expr2.eq(&expr3));
1267
1268        Ok(())
1269    }
1270
1271    #[test]
1272    fn test_column_or_null_specialization() -> Result<()> {
1273        // create input data
1274        let mut c1 = Int32Builder::new();
1275        let mut c2 = StringBuilder::new();
1276        for i in 0..1000 {
1277            c1.append_value(i);
1278            if i % 7 == 0 {
1279                c2.append_null();
1280            } else {
1281                c2.append_value(format!("string {i}"));
1282            }
1283        }
1284        let c1 = Arc::new(c1.finish());
1285        let c2 = Arc::new(c2.finish());
1286        let schema = Schema::new(vec![
1287            Field::new("c1", DataType::Int32, true),
1288            Field::new("c2", DataType::Utf8, true),
1289        ]);
1290        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
1291
1292        // CaseWhenExprOrNull should produce same results as CaseExpr
1293        let predicate = Arc::new(BinaryExpr::new(
1294            make_col("c1", 0),
1295            Operator::LtEq,
1296            make_lit_i32(250),
1297        ));
1298        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
1299        assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
1300        match expr.evaluate(&batch)? {
1301            ColumnarValue::Array(array) => {
1302                assert_eq!(1000, array.len());
1303                assert_eq!(785, array.null_count());
1304            }
1305            _ => unreachable!(),
1306        }
1307        Ok(())
1308    }
1309
1310    #[test]
1311    fn test_expr_or_expr_specialization() -> Result<()> {
1312        let batch = case_test_batch1()?;
1313        let schema = batch.schema();
1314        let when = binary(
1315            col("a", &schema)?,
1316            Operator::LtEq,
1317            lit(2i32),
1318            &batch.schema(),
1319        )?;
1320        let then = col("b", &schema)?;
1321        let else_expr = col("c", &schema)?;
1322        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
1323        assert!(matches!(
1324            expr.eval_method,
1325            EvalMethod::ExpressionOrExpression
1326        ));
1327        let result = expr
1328            .evaluate(&batch)?
1329            .into_array(batch.num_rows())
1330            .expect("Failed to convert to array");
1331        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
1332
1333        let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
1334
1335        assert_eq!(expected, result);
1336        Ok(())
1337    }
1338
1339    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1340        Arc::new(Column::new(name, index))
1341    }
1342
1343    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1344        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1345    }
1346
1347    fn generate_case_when_with_type_coercion(
1348        expr: Option<Arc<dyn PhysicalExpr>>,
1349        when_thens: Vec<WhenThen>,
1350        else_expr: Option<Arc<dyn PhysicalExpr>>,
1351        input_schema: &Schema,
1352    ) -> Result<Arc<dyn PhysicalExpr>> {
1353        let coerce_type =
1354            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
1355        let (when_thens, else_expr) = match coerce_type {
1356            None => plan_err!(
1357                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
1358            ),
1359            Some(data_type) => {
1360                // cast then expr
1361                let left = when_thens
1362                    .into_iter()
1363                    .map(|(when, then)| {
1364                        let then = try_cast(then, input_schema, data_type.clone())?;
1365                        Ok((when, then))
1366                    })
1367                    .collect::<Result<Vec<_>>>()?;
1368                let right = match else_expr {
1369                    None => None,
1370                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
1371                };
1372
1373                Ok((left, right))
1374            }
1375        }?;
1376        case(expr, when_thens, else_expr)
1377    }
1378
1379    fn get_case_common_type(
1380        when_thens: &[WhenThen],
1381        else_expr: Option<Arc<dyn PhysicalExpr>>,
1382        input_schema: &Schema,
1383    ) -> Option<DataType> {
1384        let thens_type = when_thens
1385            .iter()
1386            .map(|when_then| {
1387                let data_type = &when_then.1.data_type(input_schema).unwrap();
1388                data_type.clone()
1389            })
1390            .collect::<Vec<_>>();
1391        let else_type = match else_expr {
1392            None => {
1393                // case when then exprs must have one then value
1394                thens_type[0].clone()
1395            }
1396            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
1397        };
1398        thens_type
1399            .iter()
1400            .try_fold(else_type, |left_type, right_type| {
1401                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
1402                // refactor again.
1403                comparison_coercion(&left_type, right_type)
1404            })
1405    }
1406
1407    #[test]
1408    fn test_fmt_sql() -> Result<()> {
1409        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1410
1411        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
1412        let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
1413        let then = lit(123.3f64);
1414        let else_value = lit(999i32);
1415
1416        let expr = generate_case_when_with_type_coercion(
1417            None,
1418            vec![(when, then)],
1419            Some(else_value),
1420            &schema,
1421        )?;
1422
1423        let display_string = expr.to_string();
1424        assert_eq!(
1425            display_string,
1426            "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
1427        );
1428
1429        let sql_string = fmt_sql(expr.as_ref()).to_string();
1430        assert_eq!(
1431            sql_string,
1432            "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
1433        );
1434
1435        Ok(())
1436    }
1437}