datafusion_expr/
conditional_expressions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Conditional expressions
19use crate::expr::Case;
20use crate::{expr_schema::ExprSchemable, Expr};
21use arrow::datatypes::DataType;
22use datafusion_common::{plan_err, DFSchema, HashSet, Result};
23use itertools::Itertools as _;
24
25/// Helper struct for building [Expr::Case]
26#[derive(Debug, Clone)]
27pub struct CaseBuilder {
28    expr: Option<Box<Expr>>,
29    when_expr: Vec<Expr>,
30    then_expr: Vec<Expr>,
31    else_expr: Option<Box<Expr>>,
32}
33
34impl CaseBuilder {
35    pub fn new(
36        expr: Option<Box<Expr>>,
37        when_expr: Vec<Expr>,
38        then_expr: Vec<Expr>,
39        else_expr: Option<Box<Expr>>,
40    ) -> Self {
41        Self {
42            expr,
43            when_expr,
44            then_expr,
45            else_expr,
46        }
47    }
48    pub fn when(&mut self, when: Expr, then: Expr) -> CaseBuilder {
49        self.when_expr.push(when);
50        self.then_expr.push(then);
51        CaseBuilder {
52            expr: self.expr.clone(),
53            when_expr: self.when_expr.clone(),
54            then_expr: self.then_expr.clone(),
55            else_expr: self.else_expr.clone(),
56        }
57    }
58    pub fn otherwise(&mut self, else_expr: Expr) -> Result<Expr> {
59        self.else_expr = Some(Box::new(else_expr));
60        self.build()
61    }
62
63    pub fn end(&self) -> Result<Expr> {
64        self.build()
65    }
66
67    fn build(&self) -> Result<Expr> {
68        // Collect all "then" expressions
69        let mut then_expr = self.then_expr.clone();
70        if let Some(e) = &self.else_expr {
71            then_expr.push(e.as_ref().to_owned());
72        }
73
74        let then_types: Vec<DataType> = then_expr
75            .iter()
76            .map(|e| match e {
77                Expr::Literal(_, _) => e.get_type(&DFSchema::empty()),
78                _ => Ok(DataType::Null),
79            })
80            .collect::<Result<Vec<_>>>()?;
81
82        if then_types.contains(&DataType::Null) {
83            // Cannot verify types until execution type
84        } else {
85            let unique_types: HashSet<&DataType> = then_types.iter().collect();
86            if unique_types.is_empty() {
87                return plan_err!("CASE expression 'then' values had no data types");
88            } else if unique_types.len() != 1 {
89                return plan_err!(
90                    "CASE expression 'then' values had multiple data types: {}",
91                    unique_types.iter().join(", ")
92                );
93            }
94        }
95
96        Ok(Expr::Case(Case::new(
97            self.expr.clone(),
98            self.when_expr
99                .iter()
100                .zip(self.then_expr.iter())
101                .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
102                .collect(),
103            self.else_expr.clone(),
104        )))
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit, when};

    /// Two integer-literal `then` values share a type, so building succeeds.
    #[test]
    fn case_when_same_literal_then_types() -> Result<()> {
        let mut builder = when(col("state").eq(lit("CO")), lit(303));
        builder
            .when(col("state").eq(lit("NY")), lit(212))
            .end()
            .map(|_expr| ())
    }

    /// An integer literal mixed with a string literal must be rejected.
    #[test]
    fn case_when_different_literal_then_types() {
        let mut builder = when(col("state").eq(lit("CO")), lit(303));
        let outcome = builder
            .when(col("state").eq(lit("NY")), lit("212"))
            .end();
        assert!(outcome.is_err());
    }
}