datafusion_physical_expr/expressions/
case.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::expressions::try_cast;
19use crate::PhysicalExpr;
20use std::borrow::Cow;
21use std::hash::Hash;
22use std::{any::Any, sync::Arc};
23
24use arrow::array::*;
25use arrow::compute::kernels::zip::zip;
26use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
27use arrow::datatypes::{DataType, Schema};
28use datafusion_common::cast::as_boolean_array;
29use datafusion_common::{
30    exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
31};
32use datafusion_expr::ColumnarValue;
33
34use super::{Column, Literal};
35use datafusion_physical_expr_common::datum::compare_with_eq;
36use itertools::Itertools;
37
/// A single `WHEN ... THEN ...` branch: the first element is the "when"
/// expression and the second element is the matching "then" expression.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
39
/// Evaluation strategy for a [`CaseExpr`].
///
/// The strategy is chosen once, in `CaseExpr::try_new`, from the shape of the
/// expression (presence of a base expression, number of branches, and the kind
/// of the THEN / ELSE expressions).
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
    /// General form without a base expression; every WHEN is a boolean condition.
    /// Used when no base expression is present and no specialization applies.
    ///
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    NoExpression,
    /// Form with a base expression that is compared against each WHEN value.
    /// Chosen whenever a base expression is present.
    ///
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    WithExpression,
    /// This is a specialization for a specific use case where we can take a fast path
    /// for expressions that are infallible and can be cheaply computed for the entire
    /// record batch rather than just for the rows where the predicate is true.
    /// Chosen when there is exactly one when/then pair, the `then` expression is
    /// cheap and infallible, and there is no `else` expression.
    ///
    /// CASE WHEN condition THEN column [ELSE NULL] END
    InfallibleExprOrNull,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair and both the `then` and `else` expressions
    /// are literal values.
    ///
    /// CASE WHEN condition THEN literal ELSE literal END
    ScalarOrScalar,
    /// This is a specialization for a specific use case where we can take a fast path
    /// if there is just one when/then pair and both the `then` and `else` expressions
    /// are cheap and infallible.
    ///
    /// CASE WHEN condition THEN expression ELSE expression END
    ExpressionOrExpression,
}
70
/// Physical CASE expression.
///
/// The CASE expression is similar to a series of nested if/else and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
/// CASE WHEN condition THEN result
///      [WHEN ...]
///      [ELSE result]
/// END
///
/// The second form uses a base expression and then a series of "when" clauses whose values
/// are compared with the base expression for equality.
///
/// CASE expression
///     WHEN value THEN result
///     [WHEN ...]
///     [ELSE result]
/// END
///
/// Rows that match no "when" clause produce the "else" result, or NULL when there is no
/// "else" expression.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
    /// Optional base expression that is compared to the values of the "when" expressions
    expr: Option<Arc<dyn PhysicalExpr>>,
    /// One or more when/then expressions, evaluated in order
    when_then_expr: Vec<WhenThen>,
    /// Optional "else" expression; a NULL literal is normalized to `None` on construction
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    /// Evaluation method to use, selected on construction
    eval_method: EvalMethod,
}
99
100impl std::fmt::Display for CaseExpr {
101    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102        write!(f, "CASE ")?;
103        if let Some(e) = &self.expr {
104            write!(f, "{e} ")?;
105        }
106        for (w, t) in &self.when_then_expr {
107            write!(f, "WHEN {w} THEN {t} ")?;
108        }
109        if let Some(e) = &self.else_expr {
110            write!(f, "ELSE {e} ")?;
111        }
112        write!(f, "END")
113    }
114}
115
116/// This is a specialization for a specific use case where we can take a fast path
117/// for expressions that are infallible and can be cheaply computed for the entire
118/// record batch rather than just for the rows where the predicate is true. For now,
119/// this is limited to use with Column expressions but could potentially be used for other
120/// expressions in the future
121fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
122    expr.as_any().is::<Column>()
123}
124
125impl CaseExpr {
126    /// Create a new CASE WHEN expression
127    pub fn try_new(
128        expr: Option<Arc<dyn PhysicalExpr>>,
129        when_then_expr: Vec<WhenThen>,
130        else_expr: Option<Arc<dyn PhysicalExpr>>,
131    ) -> Result<Self> {
132        // normalize null literals to None in the else_expr (this already happens
133        // during SQL planning, but not necessarily for other use cases)
134        let else_expr = match &else_expr {
135            Some(e) => match e.as_any().downcast_ref::<Literal>() {
136                Some(lit) if lit.value().is_null() => None,
137                _ => else_expr,
138            },
139            _ => else_expr,
140        };
141
142        if when_then_expr.is_empty() {
143            exec_err!("There must be at least one WHEN clause")
144        } else {
145            let eval_method = if expr.is_some() {
146                EvalMethod::WithExpression
147            } else if when_then_expr.len() == 1
148                && is_cheap_and_infallible(&(when_then_expr[0].1))
149                && else_expr.is_none()
150            {
151                EvalMethod::InfallibleExprOrNull
152            } else if when_then_expr.len() == 1
153                && when_then_expr[0].1.as_any().is::<Literal>()
154                && else_expr.is_some()
155                && else_expr.as_ref().unwrap().as_any().is::<Literal>()
156            {
157                EvalMethod::ScalarOrScalar
158            } else if when_then_expr.len() == 1
159                && is_cheap_and_infallible(&(when_then_expr[0].1))
160                && else_expr.as_ref().is_some_and(is_cheap_and_infallible)
161            {
162                EvalMethod::ExpressionOrExpression
163            } else {
164                EvalMethod::NoExpression
165            };
166
167            Ok(Self {
168                expr,
169                when_then_expr,
170                else_expr,
171                eval_method,
172            })
173        }
174    }
175
176    /// Optional base expression that can be compared to literal values in the "when" expressions
177    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
178        self.expr.as_ref()
179    }
180
181    /// One or more when/then expressions
182    pub fn when_then_expr(&self) -> &[WhenThen] {
183        &self.when_then_expr
184    }
185
186    /// Optional "else" expression
187    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
188        self.else_expr.as_ref()
189    }
190}
191
192impl CaseExpr {
193    /// This function evaluates the form of CASE that matches an expression to fixed values.
194    ///
195    /// CASE expression
196    ///     WHEN value THEN result
197    ///     [WHEN ...]
198    ///     [ELSE result]
199    /// END
200    fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
201        let return_type = self.data_type(&batch.schema())?;
202        let expr = self.expr.as_ref().unwrap();
203        let base_value = expr.evaluate(batch)?;
204        let base_value = base_value.into_array(batch.num_rows())?;
205        let base_nulls = is_null(base_value.as_ref())?;
206
207        // start with nulls as default output
208        let mut current_value = new_null_array(&return_type, batch.num_rows());
209        // We only consider non-null values while comparing with whens
210        let mut remainder = not(&base_nulls)?;
211        for i in 0..self.when_then_expr.len() {
212            let when_value = self.when_then_expr[i]
213                .0
214                .evaluate_selection(batch, &remainder)?;
215            let when_value = when_value.into_array(batch.num_rows())?;
216            // build boolean array representing which rows match the "when" value
217            let when_match = compare_with_eq(
218                &when_value,
219                &base_value,
220                // The types of case and when expressions will be coerced to match.
221                // We only need to check if the base_value is nested.
222                base_value.data_type().is_nested(),
223            )?;
224            // Treat nulls as false
225            let when_match = match when_match.null_count() {
226                0 => Cow::Borrowed(&when_match),
227                _ => Cow::Owned(prep_null_mask_filter(&when_match)),
228            };
229            // Make sure we only consider rows that have not been matched yet
230            let when_match = and(&when_match, &remainder)?;
231
232            // When no rows available for when clause, skip then clause
233            if when_match.true_count() == 0 {
234                continue;
235            }
236
237            let then_value = self.when_then_expr[i]
238                .1
239                .evaluate_selection(batch, &when_match)?;
240
241            current_value = match then_value {
242                ColumnarValue::Scalar(ScalarValue::Null) => {
243                    nullif(current_value.as_ref(), &when_match)?
244                }
245                ColumnarValue::Scalar(then_value) => {
246                    zip(&when_match, &then_value.to_scalar()?, &current_value)?
247                }
248                ColumnarValue::Array(then_value) => {
249                    zip(&when_match, &then_value, &current_value)?
250                }
251            };
252
253            remainder = and_not(&remainder, &when_match)?;
254        }
255
256        if let Some(e) = self.else_expr() {
257            // keep `else_expr`'s data type and return type consistent
258            let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
259            // null and unmatched tuples should be assigned else value
260            remainder = or(&base_nulls, &remainder)?;
261            let else_ = expr
262                .evaluate_selection(batch, &remainder)?
263                .into_array(batch.num_rows())?;
264            current_value = zip(&remainder, &else_, &current_value)?;
265        }
266
267        Ok(ColumnarValue::Array(current_value))
268    }
269
270    /// This function evaluates the form of CASE where each WHEN expression is a boolean
271    /// expression.
272    ///
273    /// CASE WHEN condition THEN result
274    ///      [WHEN ...]
275    ///      [ELSE result]
276    /// END
277    fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
278        let return_type = self.data_type(&batch.schema())?;
279
280        // start with nulls as default output
281        let mut current_value = new_null_array(&return_type, batch.num_rows());
282        let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
283        for i in 0..self.when_then_expr.len() {
284            let when_value = self.when_then_expr[i]
285                .0
286                .evaluate_selection(batch, &remainder)?;
287            let when_value = when_value.into_array(batch.num_rows())?;
288            let when_value = as_boolean_array(&when_value).map_err(|_| {
289                internal_datafusion_err!("WHEN expression did not return a BooleanArray")
290            })?;
291            // Treat 'NULL' as false value
292            let when_value = match when_value.null_count() {
293                0 => Cow::Borrowed(when_value),
294                _ => Cow::Owned(prep_null_mask_filter(when_value)),
295            };
296            // Make sure we only consider rows that have not been matched yet
297            let when_value = and(&when_value, &remainder)?;
298
299            // When no rows available for when clause, skip then clause
300            if when_value.true_count() == 0 {
301                continue;
302            }
303
304            let then_value = self.when_then_expr[i]
305                .1
306                .evaluate_selection(batch, &when_value)?;
307
308            current_value = match then_value {
309                ColumnarValue::Scalar(ScalarValue::Null) => {
310                    nullif(current_value.as_ref(), &when_value)?
311                }
312                ColumnarValue::Scalar(then_value) => {
313                    zip(&when_value, &then_value.to_scalar()?, &current_value)?
314                }
315                ColumnarValue::Array(then_value) => {
316                    zip(&when_value, &then_value, &current_value)?
317                }
318            };
319
320            // Succeed tuples should be filtered out for short-circuit evaluation,
321            // null values for the current when expr should be kept
322            remainder = and_not(&remainder, &when_value)?;
323        }
324
325        if let Some(e) = self.else_expr() {
326            if remainder.true_count() > 0 {
327                // keep `else_expr`'s data type and return type consistent
328                let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
329                let else_ = expr
330                    .evaluate_selection(batch, &remainder)?
331                    .into_array(batch.num_rows())?;
332                current_value = zip(&remainder, &else_, &current_value)?;
333            }
334        }
335
336        Ok(ColumnarValue::Array(current_value))
337    }
338
339    /// This function evaluates the specialized case of:
340    ///
341    /// CASE WHEN condition THEN column
342    ///      [ELSE NULL]
343    /// END
344    ///
345    /// Note that this function is only safe to use for "then" expressions
346    /// that are infallible because the expression will be evaluated for all
347    /// rows in the input batch.
348    fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
349        let when_expr = &self.when_then_expr[0].0;
350        let then_expr = &self.when_then_expr[0].1;
351
352        match when_expr.evaluate(batch)? {
353            // WHEN true --> column
354            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
355                then_expr.evaluate(batch)
356            }
357            // WHEN [false | null] --> NULL
358            ColumnarValue::Scalar(_) => {
359                // return scalar NULL value
360                ScalarValue::try_from(self.data_type(&batch.schema())?)
361                    .map(ColumnarValue::Scalar)
362            }
363            // WHEN column --> column
364            ColumnarValue::Array(bit_mask) => {
365                let bit_mask = bit_mask
366                    .as_any()
367                    .downcast_ref::<BooleanArray>()
368                    .expect("predicate should evaluate to a boolean array");
369                // invert the bitmask
370                let bit_mask = match bit_mask.null_count() {
371                    0 => not(bit_mask)?,
372                    _ => not(&prep_null_mask_filter(bit_mask))?,
373                };
374                match then_expr.evaluate(batch)? {
375                    ColumnarValue::Array(array) => {
376                        Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
377                    }
378                    ColumnarValue::Scalar(_) => {
379                        internal_err!("expression did not evaluate to an array")
380                    }
381                }
382            }
383        }
384    }
385
386    fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
387        let return_type = self.data_type(&batch.schema())?;
388
389        // evaluate when expression
390        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
391        let when_value = when_value.into_array(batch.num_rows())?;
392        let when_value = as_boolean_array(&when_value).map_err(|_| {
393            internal_datafusion_err!("WHEN expression did not return a BooleanArray")
394        })?;
395
396        // Treat 'NULL' as false value
397        let when_value = match when_value.null_count() {
398            0 => Cow::Borrowed(when_value),
399            _ => Cow::Owned(prep_null_mask_filter(when_value)),
400        };
401
402        // evaluate then_value
403        let then_value = self.when_then_expr[0].1.evaluate(batch)?;
404        let then_value = Scalar::new(then_value.into_array(1)?);
405
406        let Some(e) = self.else_expr() else {
407            return internal_err!("expression did not evaluate to an array");
408        };
409        // keep `else_expr`'s data type and return type consistent
410        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
411        let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
412        Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
413    }
414
415    fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
416        let return_type = self.data_type(&batch.schema())?;
417
418        // evaluate when condition on batch
419        let when_value = self.when_then_expr[0].0.evaluate(batch)?;
420        let when_value = when_value.into_array(batch.num_rows())?;
421        let when_value = as_boolean_array(&when_value).map_err(|e| {
422            DataFusionError::Context(
423                "WHEN expression did not return a BooleanArray".to_string(),
424                Box::new(e),
425            )
426        })?;
427
428        // Treat 'NULL' as false value
429        let when_value = match when_value.null_count() {
430            0 => Cow::Borrowed(when_value),
431            _ => Cow::Owned(prep_null_mask_filter(when_value)),
432        };
433
434        let then_value = self.when_then_expr[0]
435            .1
436            .evaluate_selection(batch, &when_value)?
437            .into_array(batch.num_rows())?;
438
439        // evaluate else expression on the values not covered by when_value
440        let remainder = not(&when_value)?;
441        let e = self.else_expr.as_ref().unwrap();
442        // keep `else_expr`'s data type and return type consistent
443        let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
444            .unwrap_or_else(|_| Arc::clone(e));
445        let else_ = expr
446            .evaluate_selection(batch, &remainder)?
447            .into_array(batch.num_rows())?;
448
449        Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
450    }
451}
452
453impl PhysicalExpr for CaseExpr {
    /// Returns `self` as `&dyn Any` so that callers can down-cast a
    /// `dyn PhysicalExpr` back to the concrete [`CaseExpr`] type.
    fn as_any(&self) -> &dyn Any {
        self
    }
458
459    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
460        // since all then results have the same data type, we can choose any one as the
461        // return data type except for the null.
462        let mut data_type = DataType::Null;
463        for i in 0..self.when_then_expr.len() {
464            data_type = self.when_then_expr[i].1.data_type(input_schema)?;
465            if !data_type.equals_datatype(&DataType::Null) {
466                break;
467            }
468        }
469        // if all then results are null, we use data type of else expr instead if possible.
470        if data_type.equals_datatype(&DataType::Null) {
471            if let Some(e) = &self.else_expr {
472                data_type = e.data_type(input_schema)?;
473            }
474        }
475
476        Ok(data_type)
477    }
478
479    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
480        // this expression is nullable if any of the input expressions are nullable
481        let then_nullable = self
482            .when_then_expr
483            .iter()
484            .map(|(_, t)| t.nullable(input_schema))
485            .collect::<Result<Vec<_>>>()?;
486        if then_nullable.contains(&true) {
487            Ok(true)
488        } else if let Some(e) = &self.else_expr {
489            e.nullable(input_schema)
490        } else {
491            // CASE produces NULL if there is no `else` expr
492            // (aka when none of the `when_then_exprs` match)
493            Ok(true)
494        }
495    }
496
497    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
498        match self.eval_method {
499            EvalMethod::WithExpression => {
500                // this use case evaluates "expr" and then compares the values with the "when"
501                // values
502                self.case_when_with_expr(batch)
503            }
504            EvalMethod::NoExpression => {
505                // The "when" conditions all evaluate to boolean in this use case and can be
506                // arbitrary expressions
507                self.case_when_no_expr(batch)
508            }
509            EvalMethod::InfallibleExprOrNull => {
510                // Specialization for CASE WHEN expr THEN column [ELSE NULL] END
511                self.case_column_or_null(batch)
512            }
513            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
514            EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
515        }
516    }
517
518    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
519        let mut children = vec![];
520        if let Some(expr) = &self.expr {
521            children.push(expr)
522        }
523        self.when_then_expr.iter().for_each(|(cond, value)| {
524            children.push(cond);
525            children.push(value);
526        });
527
528        if let Some(else_expr) = &self.else_expr {
529            children.push(else_expr)
530        }
531        children
532    }
533
534    // For physical CaseExpr, we do not allow modifying children size
535    fn with_new_children(
536        self: Arc<Self>,
537        children: Vec<Arc<dyn PhysicalExpr>>,
538    ) -> Result<Arc<dyn PhysicalExpr>> {
539        if children.len() != self.children().len() {
540            internal_err!("CaseExpr: Wrong number of children")
541        } else {
542            let (expr, when_then_expr, else_expr) =
543                match (self.expr().is_some(), self.else_expr().is_some()) {
544                    (true, true) => (
545                        Some(&children[0]),
546                        &children[1..children.len() - 1],
547                        Some(&children[children.len() - 1]),
548                    ),
549                    (true, false) => {
550                        (Some(&children[0]), &children[1..children.len()], None)
551                    }
552                    (false, true) => (
553                        None,
554                        &children[0..children.len() - 1],
555                        Some(&children[children.len() - 1]),
556                    ),
557                    (false, false) => (None, &children[0..children.len()], None),
558                };
559            Ok(Arc::new(CaseExpr::try_new(
560                expr.cloned(),
561                when_then_expr.iter().cloned().tuples().collect(),
562                else_expr.cloned(),
563            )?))
564        }
565    }
566
567    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
568        write!(f, "CASE ")?;
569        if let Some(e) = &self.expr {
570            e.fmt_sql(f)?;
571            write!(f, " ")?;
572        }
573
574        for (w, t) in &self.when_then_expr {
575            write!(f, "WHEN ")?;
576            w.fmt_sql(f)?;
577            write!(f, " THEN ")?;
578            t.fmt_sql(f)?;
579            write!(f, " ")?;
580        }
581
582        if let Some(e) = &self.else_expr {
583            write!(f, "ELSE ")?;
584            e.fmt_sql(f)?;
585            write!(f, " ")?;
586        }
587        write!(f, "END")
588    }
589}
590
591/// Create a CASE expression
592pub fn case(
593    expr: Option<Arc<dyn PhysicalExpr>>,
594    when_thens: Vec<WhenThen>,
595    else_expr: Option<Arc<dyn PhysicalExpr>>,
596) -> Result<Arc<dyn PhysicalExpr>> {
597    Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
598}
599
600#[cfg(test)]
601mod tests {
602    use super::*;
603
604    use crate::expressions::{binary, cast, col, lit, BinaryExpr};
605    use arrow::buffer::Buffer;
606    use arrow::datatypes::DataType::Float64;
607    use arrow::datatypes::Field;
608    use datafusion_common::cast::{as_float64_array, as_int32_array};
609    use datafusion_common::plan_err;
610    use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
611    use datafusion_expr::type_coercion::binary::comparison_coercion;
612    use datafusion_expr::Operator;
613    use datafusion_physical_expr_common::physical_expr::fmt_sql;
614
615    #[test]
616    fn case_with_expr() -> Result<()> {
617        let batch = case_test_batch()?;
618        let schema = batch.schema();
619
620        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
621        let when1 = lit("foo");
622        let then1 = lit(123i32);
623        let when2 = lit("bar");
624        let then2 = lit(456i32);
625
626        let expr = generate_case_when_with_type_coercion(
627            Some(col("a", &schema)?),
628            vec![(when1, then1), (when2, then2)],
629            None,
630            schema.as_ref(),
631        )?;
632        let result = expr
633            .evaluate(&batch)?
634            .into_array(batch.num_rows())
635            .expect("Failed to convert to array");
636        let result = as_int32_array(&result)?;
637
638        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
639
640        assert_eq!(expected, result);
641
642        Ok(())
643    }
644
645    #[test]
646    fn case_with_expr_else() -> Result<()> {
647        let batch = case_test_batch()?;
648        let schema = batch.schema();
649
650        // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
651        let when1 = lit("foo");
652        let then1 = lit(123i32);
653        let when2 = lit("bar");
654        let then2 = lit(456i32);
655        let else_value = lit(999i32);
656
657        let expr = generate_case_when_with_type_coercion(
658            Some(col("a", &schema)?),
659            vec![(when1, then1), (when2, then2)],
660            Some(else_value),
661            schema.as_ref(),
662        )?;
663        let result = expr
664            .evaluate(&batch)?
665            .into_array(batch.num_rows())
666            .expect("Failed to convert to array");
667        let result = as_int32_array(&result)?;
668
669        let expected =
670            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
671
672        assert_eq!(expected, result);
673
674        Ok(())
675    }
676
677    #[test]
678    fn case_with_expr_divide_by_zero() -> Result<()> {
679        let batch = case_test_batch1()?;
680        let schema = batch.schema();
681
682        // CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64)  END
683        let when1 = lit(0i32);
684        let then1 = lit(ScalarValue::Float64(None));
685        let else_value = binary(
686            lit(25.0f64),
687            Operator::Divide,
688            cast(col("a", &schema)?, &batch.schema(), Float64)?,
689            &batch.schema(),
690        )?;
691
692        let expr = generate_case_when_with_type_coercion(
693            Some(col("a", &schema)?),
694            vec![(when1, then1)],
695            Some(else_value),
696            schema.as_ref(),
697        )?;
698        let result = expr
699            .evaluate(&batch)?
700            .into_array(batch.num_rows())
701            .expect("Failed to convert to array");
702        let result =
703            as_float64_array(&result).expect("failed to downcast to Float64Array");
704
705        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
706
707        assert_eq!(expected, result);
708
709        Ok(())
710    }
711
712    #[test]
713    fn case_without_expr() -> Result<()> {
714        let batch = case_test_batch()?;
715        let schema = batch.schema();
716
717        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
718        let when1 = binary(
719            col("a", &schema)?,
720            Operator::Eq,
721            lit("foo"),
722            &batch.schema(),
723        )?;
724        let then1 = lit(123i32);
725        let when2 = binary(
726            col("a", &schema)?,
727            Operator::Eq,
728            lit("bar"),
729            &batch.schema(),
730        )?;
731        let then2 = lit(456i32);
732
733        let expr = generate_case_when_with_type_coercion(
734            None,
735            vec![(when1, then1), (when2, then2)],
736            None,
737            schema.as_ref(),
738        )?;
739        let result = expr
740            .evaluate(&batch)?
741            .into_array(batch.num_rows())
742            .expect("Failed to convert to array");
743        let result = as_int32_array(&result)?;
744
745        let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
746
747        assert_eq!(expected, result);
748
749        Ok(())
750    }
751
752    #[test]
753    fn case_with_expr_when_null() -> Result<()> {
754        let batch = case_test_batch()?;
755        let schema = batch.schema();
756
757        // CASE a WHEN NULL THEN 0 WHEN a THEN 123 ELSE 999 END
758        let when1 = lit(ScalarValue::Utf8(None));
759        let then1 = lit(0i32);
760        let when2 = col("a", &schema)?;
761        let then2 = lit(123i32);
762        let else_value = lit(999i32);
763
764        let expr = generate_case_when_with_type_coercion(
765            Some(col("a", &schema)?),
766            vec![(when1, then1), (when2, then2)],
767            Some(else_value),
768            schema.as_ref(),
769        )?;
770        let result = expr
771            .evaluate(&batch)?
772            .into_array(batch.num_rows())
773            .expect("Failed to convert to array");
774        let result = as_int32_array(&result)?;
775
776        let expected =
777            &Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
778
779        assert_eq!(expected, result);
780
781        Ok(())
782    }
783
784    #[test]
785    fn case_without_expr_divide_by_zero() -> Result<()> {
786        let batch = case_test_batch1()?;
787        let schema = batch.schema();
788
789        // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
790        let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
791        let then1 = binary(
792            lit(25.0f64),
793            Operator::Divide,
794            cast(col("a", &schema)?, &batch.schema(), Float64)?,
795            &batch.schema(),
796        )?;
797        let x = lit(ScalarValue::Float64(None));
798
799        let expr = generate_case_when_with_type_coercion(
800            None,
801            vec![(when1, then1)],
802            Some(x),
803            schema.as_ref(),
804        )?;
805        let result = expr
806            .evaluate(&batch)?
807            .into_array(batch.num_rows())
808            .expect("Failed to convert to array");
809        let result =
810            as_float64_array(&result).expect("failed to downcast to Float64Array");
811
812        let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
813
814        assert_eq!(expected, result);
815
816        Ok(())
817    }
818
819    fn case_test_batch1() -> Result<RecordBatch> {
820        let schema = Schema::new(vec![
821            Field::new("a", DataType::Int32, true),
822            Field::new("b", DataType::Int32, true),
823            Field::new("c", DataType::Int32, true),
824        ]);
825        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
826        let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
827        let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
828        let batch = RecordBatch::try_new(
829            Arc::new(schema),
830            vec![Arc::new(a), Arc::new(b), Arc::new(c)],
831        )?;
832        Ok(batch)
833    }
834
835    #[test]
836    fn case_without_expr_else() -> Result<()> {
837        let batch = case_test_batch()?;
838        let schema = batch.schema();
839
840        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
841        let when1 = binary(
842            col("a", &schema)?,
843            Operator::Eq,
844            lit("foo"),
845            &batch.schema(),
846        )?;
847        let then1 = lit(123i32);
848        let when2 = binary(
849            col("a", &schema)?,
850            Operator::Eq,
851            lit("bar"),
852            &batch.schema(),
853        )?;
854        let then2 = lit(456i32);
855        let else_value = lit(999i32);
856
857        let expr = generate_case_when_with_type_coercion(
858            None,
859            vec![(when1, then1), (when2, then2)],
860            Some(else_value),
861            schema.as_ref(),
862        )?;
863        let result = expr
864            .evaluate(&batch)?
865            .into_array(batch.num_rows())
866            .expect("Failed to convert to array");
867        let result = as_int32_array(&result)?;
868
869        let expected =
870            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
871
872        assert_eq!(expected, result);
873
874        Ok(())
875    }
876
877    #[test]
878    fn case_with_type_cast() -> Result<()> {
879        let batch = case_test_batch()?;
880        let schema = batch.schema();
881
882        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
883        let when = binary(
884            col("a", &schema)?,
885            Operator::Eq,
886            lit("foo"),
887            &batch.schema(),
888        )?;
889        let then = lit(123.3f64);
890        let else_value = lit(999i32);
891
892        let expr = generate_case_when_with_type_coercion(
893            None,
894            vec![(when, then)],
895            Some(else_value),
896            schema.as_ref(),
897        )?;
898        let result = expr
899            .evaluate(&batch)?
900            .into_array(batch.num_rows())
901            .expect("Failed to convert to array");
902        let result =
903            as_float64_array(&result).expect("failed to downcast to Float64Array");
904
905        let expected =
906            &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
907
908        assert_eq!(expected, result);
909
910        Ok(())
911    }
912
913    #[test]
914    fn case_with_matches_and_nulls() -> Result<()> {
915        let batch = case_test_batch_nulls()?;
916        let schema = batch.schema();
917
918        // SELECT CASE WHEN load4 = 1.77 THEN load4 END
919        let when = binary(
920            col("load4", &schema)?,
921            Operator::Eq,
922            lit(1.77f64),
923            &batch.schema(),
924        )?;
925        let then = col("load4", &schema)?;
926
927        let expr = generate_case_when_with_type_coercion(
928            None,
929            vec![(when, then)],
930            None,
931            schema.as_ref(),
932        )?;
933        let result = expr
934            .evaluate(&batch)?
935            .into_array(batch.num_rows())
936            .expect("Failed to convert to array");
937        let result =
938            as_float64_array(&result).expect("failed to downcast to Float64Array");
939
940        let expected =
941            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
942
943        assert_eq!(expected, result);
944
945        Ok(())
946    }
947
948    #[test]
949    fn case_with_scalar_predicate() -> Result<()> {
950        let batch = case_test_batch_nulls()?;
951        let schema = batch.schema();
952
953        // SELECT CASE WHEN TRUE THEN load4 END
954        let when = lit(true);
955        let then = col("load4", &schema)?;
956        let expr = generate_case_when_with_type_coercion(
957            None,
958            vec![(when, then)],
959            None,
960            schema.as_ref(),
961        )?;
962
963        // many rows
964        let result = expr
965            .evaluate(&batch)?
966            .into_array(batch.num_rows())
967            .expect("Failed to convert to array");
968        let result =
969            as_float64_array(&result).expect("failed to downcast to Float64Array");
970        let expected = &Float64Array::from(vec![
971            Some(1.77),
972            None,
973            None,
974            Some(1.78),
975            None,
976            Some(1.77),
977        ]);
978        assert_eq!(expected, result);
979
980        // one row
981        let expected = Float64Array::from(vec![Some(1.1)]);
982        let batch =
983            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
984        let result = expr
985            .evaluate(&batch)?
986            .into_array(batch.num_rows())
987            .expect("Failed to convert to array");
988        let result =
989            as_float64_array(&result).expect("failed to downcast to Float64Array");
990        assert_eq!(&expected, result);
991
992        Ok(())
993    }
994
995    #[test]
996    fn case_expr_matches_and_nulls() -> Result<()> {
997        let batch = case_test_batch_nulls()?;
998        let schema = batch.schema();
999
1000        // SELECT CASE load4 WHEN 1.77 THEN load4 END
1001        let expr = col("load4", &schema)?;
1002        let when = lit(1.77f64);
1003        let then = col("load4", &schema)?;
1004
1005        let expr = generate_case_when_with_type_coercion(
1006            Some(expr),
1007            vec![(when, then)],
1008            None,
1009            schema.as_ref(),
1010        )?;
1011        let result = expr
1012            .evaluate(&batch)?
1013            .into_array(batch.num_rows())
1014            .expect("Failed to convert to array");
1015        let result =
1016            as_float64_array(&result).expect("failed to downcast to Float64Array");
1017
1018        let expected =
1019            &Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
1020
1021        assert_eq!(expected, result);
1022
1023        Ok(())
1024    }
1025
1026    #[test]
1027    fn test_when_null_and_some_cond_else_null() -> Result<()> {
1028        let batch = case_test_batch()?;
1029        let schema = batch.schema();
1030
1031        let when = binary(
1032            Arc::new(Literal::new(ScalarValue::Boolean(None))),
1033            Operator::And,
1034            binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?,
1035            &schema,
1036        )?;
1037        let then = col("a", &schema)?;
1038
1039        // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END
1040        let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?);
1041        let result = expr
1042            .evaluate(&batch)?
1043            .into_array(batch.num_rows())
1044            .expect("Failed to convert to array");
1045        let result = as_string_array(&result);
1046
1047        // all result values should be null
1048        assert_eq!(result.logical_null_count(), batch.num_rows());
1049        Ok(())
1050    }
1051
1052    fn case_test_batch() -> Result<RecordBatch> {
1053        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1054        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
1055        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
1056        Ok(batch)
1057    }
1058
1059    // Construct an array that has several NULL values whose
1060    // underlying buffer actually matches the where expr predicate
1061    fn case_test_batch_nulls() -> Result<RecordBatch> {
1062        let load4: Float64Array = vec![
1063            Some(1.77), // 1.77
1064            Some(1.77), // null <-- same value, but will be set to null
1065            Some(1.77), // null <-- same value, but will be set to null
1066            Some(1.78), // 1.78
1067            None,       // null
1068            Some(1.77), // 1.77
1069        ]
1070        .into_iter()
1071        .collect();
1072
1073        let null_buffer = Buffer::from([0b00101001u8]);
1074        let load4 = load4
1075            .into_data()
1076            .into_builder()
1077            .null_bit_buffer(Some(null_buffer))
1078            .build()
1079            .unwrap();
1080        let load4: Float64Array = load4.into();
1081
1082        let batch =
1083            RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
1084        Ok(batch)
1085    }
1086
1087    #[test]
1088    fn case_test_incompatible() -> Result<()> {
1089        // 1 then is int64
1090        // 2 then is boolean
1091        let batch = case_test_batch()?;
1092        let schema = batch.schema();
1093
1094        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN true END
1095        let when1 = binary(
1096            col("a", &schema)?,
1097            Operator::Eq,
1098            lit("foo"),
1099            &batch.schema(),
1100        )?;
1101        let then1 = lit(123i32);
1102        let when2 = binary(
1103            col("a", &schema)?,
1104            Operator::Eq,
1105            lit("bar"),
1106            &batch.schema(),
1107        )?;
1108        let then2 = lit(true);
1109
1110        let expr = generate_case_when_with_type_coercion(
1111            None,
1112            vec![(when1, then1), (when2, then2)],
1113            None,
1114            schema.as_ref(),
1115        );
1116        assert!(expr.is_err());
1117
1118        // then 1 is int32
1119        // then 2 is int64
1120        // else is float
1121        // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 1.23 END
1122        let when1 = binary(
1123            col("a", &schema)?,
1124            Operator::Eq,
1125            lit("foo"),
1126            &batch.schema(),
1127        )?;
1128        let then1 = lit(123i32);
1129        let when2 = binary(
1130            col("a", &schema)?,
1131            Operator::Eq,
1132            lit("bar"),
1133            &batch.schema(),
1134        )?;
1135        let then2 = lit(456i64);
1136        let else_expr = lit(1.23f64);
1137
1138        let expr = generate_case_when_with_type_coercion(
1139            None,
1140            vec![(when1, then1), (when2, then2)],
1141            Some(else_expr),
1142            schema.as_ref(),
1143        );
1144        assert!(expr.is_ok());
1145        let result_type = expr.unwrap().data_type(schema.as_ref())?;
1146        assert_eq!(Float64, result_type);
1147        Ok(())
1148    }
1149
1150    #[test]
1151    fn case_eq() -> Result<()> {
1152        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1153
1154        let when1 = lit("foo");
1155        let then1 = lit(123i32);
1156        let when2 = lit("bar");
1157        let then2 = lit(456i32);
1158        let else_value = lit(999i32);
1159
1160        let expr1 = generate_case_when_with_type_coercion(
1161            Some(col("a", &schema)?),
1162            vec![
1163                (Arc::clone(&when1), Arc::clone(&then1)),
1164                (Arc::clone(&when2), Arc::clone(&then2)),
1165            ],
1166            Some(Arc::clone(&else_value)),
1167            &schema,
1168        )?;
1169
1170        let expr2 = generate_case_when_with_type_coercion(
1171            Some(col("a", &schema)?),
1172            vec![
1173                (Arc::clone(&when1), Arc::clone(&then1)),
1174                (Arc::clone(&when2), Arc::clone(&then2)),
1175            ],
1176            Some(Arc::clone(&else_value)),
1177            &schema,
1178        )?;
1179
1180        let expr3 = generate_case_when_with_type_coercion(
1181            Some(col("a", &schema)?),
1182            vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)],
1183            None,
1184            &schema,
1185        )?;
1186
1187        let expr4 = generate_case_when_with_type_coercion(
1188            Some(col("a", &schema)?),
1189            vec![(when1, then1)],
1190            Some(else_value),
1191            &schema,
1192        )?;
1193
1194        assert!(expr1.eq(&expr2));
1195        assert!(expr2.eq(&expr1));
1196
1197        assert!(expr2.ne(&expr3));
1198        assert!(expr3.ne(&expr2));
1199
1200        assert!(expr1.ne(&expr4));
1201        assert!(expr4.ne(&expr1));
1202
1203        Ok(())
1204    }
1205
1206    #[test]
1207    fn case_transform() -> Result<()> {
1208        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
1209
1210        let when1 = lit("foo");
1211        let then1 = lit(123i32);
1212        let when2 = lit("bar");
1213        let then2 = lit(456i32);
1214        let else_value = lit(999i32);
1215
1216        let expr = generate_case_when_with_type_coercion(
1217            Some(col("a", &schema)?),
1218            vec![
1219                (Arc::clone(&when1), Arc::clone(&then1)),
1220                (Arc::clone(&when2), Arc::clone(&then2)),
1221            ],
1222            Some(Arc::clone(&else_value)),
1223            &schema,
1224        )?;
1225
1226        let expr2 = Arc::clone(&expr)
1227            .transform(|e| {
1228                let transformed = match e.as_any().downcast_ref::<Literal>() {
1229                    Some(lit_value) => match lit_value.value() {
1230                        ScalarValue::Utf8(Some(str_value)) => {
1231                            Some(lit(str_value.to_uppercase()))
1232                        }
1233                        _ => None,
1234                    },
1235                    _ => None,
1236                };
1237                Ok(if let Some(transformed) = transformed {
1238                    Transformed::yes(transformed)
1239                } else {
1240                    Transformed::no(e)
1241                })
1242            })
1243            .data()
1244            .unwrap();
1245
1246        let expr3 = Arc::clone(&expr)
1247            .transform_down(|e| {
1248                let transformed = match e.as_any().downcast_ref::<Literal>() {
1249                    Some(lit_value) => match lit_value.value() {
1250                        ScalarValue::Utf8(Some(str_value)) => {
1251                            Some(lit(str_value.to_uppercase()))
1252                        }
1253                        _ => None,
1254                    },
1255                    _ => None,
1256                };
1257                Ok(if let Some(transformed) = transformed {
1258                    Transformed::yes(transformed)
1259                } else {
1260                    Transformed::no(e)
1261                })
1262            })
1263            .data()
1264            .unwrap();
1265
1266        assert!(expr.ne(&expr2));
1267        assert!(expr2.eq(&expr3));
1268
1269        Ok(())
1270    }
1271
1272    #[test]
1273    fn test_column_or_null_specialization() -> Result<()> {
1274        // create input data
1275        let mut c1 = Int32Builder::new();
1276        let mut c2 = StringBuilder::new();
1277        for i in 0..1000 {
1278            c1.append_value(i);
1279            if i % 7 == 0 {
1280                c2.append_null();
1281            } else {
1282                c2.append_value(format!("string {i}"));
1283            }
1284        }
1285        let c1 = Arc::new(c1.finish());
1286        let c2 = Arc::new(c2.finish());
1287        let schema = Schema::new(vec![
1288            Field::new("c1", DataType::Int32, true),
1289            Field::new("c2", DataType::Utf8, true),
1290        ]);
1291        let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
1292
1293        // CaseWhenExprOrNull should produce same results as CaseExpr
1294        let predicate = Arc::new(BinaryExpr::new(
1295            make_col("c1", 0),
1296            Operator::LtEq,
1297            make_lit_i32(250),
1298        ));
1299        let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
1300        assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
1301        match expr.evaluate(&batch)? {
1302            ColumnarValue::Array(array) => {
1303                assert_eq!(1000, array.len());
1304                assert_eq!(785, array.null_count());
1305            }
1306            _ => unreachable!(),
1307        }
1308        Ok(())
1309    }
1310
1311    #[test]
1312    fn test_expr_or_expr_specialization() -> Result<()> {
1313        let batch = case_test_batch1()?;
1314        let schema = batch.schema();
1315        let when = binary(
1316            col("a", &schema)?,
1317            Operator::LtEq,
1318            lit(2i32),
1319            &batch.schema(),
1320        )?;
1321        let then = col("b", &schema)?;
1322        let else_expr = col("c", &schema)?;
1323        let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
1324        assert!(matches!(
1325            expr.eval_method,
1326            EvalMethod::ExpressionOrExpression
1327        ));
1328        let result = expr
1329            .evaluate(&batch)?
1330            .into_array(batch.num_rows())
1331            .expect("Failed to convert to array");
1332        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
1333
1334        let expected = &Int32Array::from(vec![Some(3), None, Some(777), None]);
1335
1336        assert_eq!(expected, result);
1337        Ok(())
1338    }
1339
1340    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
1341        Arc::new(Column::new(name, index))
1342    }
1343
1344    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
1345        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
1346    }
1347
1348    fn generate_case_when_with_type_coercion(
1349        expr: Option<Arc<dyn PhysicalExpr>>,
1350        when_thens: Vec<WhenThen>,
1351        else_expr: Option<Arc<dyn PhysicalExpr>>,
1352        input_schema: &Schema,
1353    ) -> Result<Arc<dyn PhysicalExpr>> {
1354        let coerce_type =
1355            get_case_common_type(&when_thens, else_expr.clone(), input_schema);
1356        let (when_thens, else_expr) = match coerce_type {
1357            None => plan_err!(
1358                "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
1359            ),
1360            Some(data_type) => {
1361                // cast then expr
1362                let left = when_thens
1363                    .into_iter()
1364                    .map(|(when, then)| {
1365                        let then = try_cast(then, input_schema, data_type.clone())?;
1366                        Ok((when, then))
1367                    })
1368                    .collect::<Result<Vec<_>>>()?;
1369                let right = match else_expr {
1370                    None => None,
1371                    Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?),
1372                };
1373
1374                Ok((left, right))
1375            }
1376        }?;
1377        case(expr, when_thens, else_expr)
1378    }
1379
1380    fn get_case_common_type(
1381        when_thens: &[WhenThen],
1382        else_expr: Option<Arc<dyn PhysicalExpr>>,
1383        input_schema: &Schema,
1384    ) -> Option<DataType> {
1385        let thens_type = when_thens
1386            .iter()
1387            .map(|when_then| {
1388                let data_type = &when_then.1.data_type(input_schema).unwrap();
1389                data_type.clone()
1390            })
1391            .collect::<Vec<_>>();
1392        let else_type = match else_expr {
1393            None => {
1394                // case when then exprs must have one then value
1395                thens_type[0].clone()
1396            }
1397            Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(),
1398        };
1399        thens_type
1400            .iter()
1401            .try_fold(else_type, |left_type, right_type| {
1402                // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
1403                // refactor again.
1404                comparison_coercion(&left_type, right_type)
1405            })
1406    }
1407
1408    #[test]
1409    fn test_fmt_sql() -> Result<()> {
1410        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
1411
1412        // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
1413        let when = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
1414        let then = lit(123.3f64);
1415        let else_value = lit(999i32);
1416
1417        let expr = generate_case_when_with_type_coercion(
1418            None,
1419            vec![(when, then)],
1420            Some(else_value),
1421            &schema,
1422        )?;
1423
1424        let display_string = expr.to_string();
1425        assert_eq!(
1426            display_string,
1427            "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
1428        );
1429
1430        let sql_string = fmt_sql(expr.as_ref()).to_string();
1431        assert_eq!(
1432            sql_string,
1433            "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
1434        );
1435
1436        Ok(())
1437    }
1438}