datafusion_expr/
conditional_expressions.rs1use crate::expr::Case;
20use crate::{expr_schema::ExprSchemable, Expr};
21use arrow::datatypes::DataType;
22use datafusion_common::{plan_err, DFSchema, HashSet, Result};
23use itertools::Itertools as _;
24
25#[derive(Debug, Clone)]
27pub struct CaseBuilder {
28 expr: Option<Box<Expr>>,
29 when_expr: Vec<Expr>,
30 then_expr: Vec<Expr>,
31 else_expr: Option<Box<Expr>>,
32}
33
34impl CaseBuilder {
35 pub fn new(
36 expr: Option<Box<Expr>>,
37 when_expr: Vec<Expr>,
38 then_expr: Vec<Expr>,
39 else_expr: Option<Box<Expr>>,
40 ) -> Self {
41 Self {
42 expr,
43 when_expr,
44 then_expr,
45 else_expr,
46 }
47 }
48 pub fn when(&mut self, when: Expr, then: Expr) -> CaseBuilder {
49 self.when_expr.push(when);
50 self.then_expr.push(then);
51 CaseBuilder {
52 expr: self.expr.clone(),
53 when_expr: self.when_expr.clone(),
54 then_expr: self.then_expr.clone(),
55 else_expr: self.else_expr.clone(),
56 }
57 }
58 pub fn otherwise(&mut self, else_expr: Expr) -> Result<Expr> {
59 self.else_expr = Some(Box::new(else_expr));
60 self.build()
61 }
62
63 pub fn end(&self) -> Result<Expr> {
64 self.build()
65 }
66
67 fn build(&self) -> Result<Expr> {
68 let mut then_expr = self.then_expr.clone();
70 if let Some(e) = &self.else_expr {
71 then_expr.push(e.as_ref().to_owned());
72 }
73
74 let then_types: Vec<DataType> = then_expr
75 .iter()
76 .map(|e| match e {
77 Expr::Literal(_, _) => e.get_type(&DFSchema::empty()),
78 _ => Ok(DataType::Null),
79 })
80 .collect::<Result<Vec<_>>>()?;
81
82 if then_types.contains(&DataType::Null) {
83 } else {
85 let unique_types: HashSet<&DataType> = then_types.iter().collect();
86 if unique_types.is_empty() {
87 return plan_err!("CASE expression 'then' values had no data types");
88 } else if unique_types.len() != 1 {
89 return plan_err!(
90 "CASE expression 'then' values had multiple data types: {}",
91 unique_types.iter().join(", ")
92 );
93 }
94 }
95
96 Ok(Expr::Case(Case::new(
97 self.expr.clone(),
98 self.when_expr
99 .iter()
100 .zip(self.then_expr.iter())
101 .map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
102 .collect(),
103 self.else_expr.clone(),
104 )))
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit, when};

    /// Two integer literal results share one type, so building succeeds.
    #[test]
    fn case_when_same_literal_then_types() -> Result<()> {
        let is_colorado = col("state").eq(lit("CO"));
        let is_new_york = col("state").eq(lit("NY"));

        let _ = when(is_colorado, lit(303))
            .when(is_new_york, lit(212))
            .end()?;

        Ok(())
    }

    /// An integer literal and a string literal result must be rejected.
    #[test]
    fn case_when_different_literal_then_types() {
        let is_colorado = col("state").eq(lit("CO"));
        let is_new_york = col("state").eq(lit("NY"));

        let outcome = when(is_colorado, lit(303))
            .when(is_new_york, lit("212"))
            .end();

        assert!(outcome.is_err());
    }
}