datafusion_comet_spark_expr/conditional_funcs/if_expr.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::{
19    datatypes::{DataType, Schema},
20    record_batch::RecordBatch,
21};
22use datafusion::common::Result;
23use datafusion::logical_expr::ColumnarValue;
24use datafusion::physical_expr::{expressions::CaseExpr, PhysicalExpr};
25use std::fmt::Formatter;
26use std::hash::Hash;
27use std::{any::Any, sync::Arc};
28
29/// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to
30/// `CASE WHEN a THEN b ELSE c END`.
31#[derive(Debug, Eq)]
32pub struct IfExpr {
33    if_expr: Arc<dyn PhysicalExpr>,
34    true_expr: Arc<dyn PhysicalExpr>,
35    false_expr: Arc<dyn PhysicalExpr>,
36    // we delegate to case_expr for evaluation
37    case_expr: Arc<CaseExpr>,
38}
39
40impl Hash for IfExpr {
41    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42        self.if_expr.hash(state);
43        self.true_expr.hash(state);
44        self.false_expr.hash(state);
45        self.case_expr.hash(state);
46    }
47}
48impl PartialEq for IfExpr {
49    fn eq(&self, other: &Self) -> bool {
50        self.if_expr.eq(&other.if_expr)
51            && self.true_expr.eq(&other.true_expr)
52            && self.false_expr.eq(&other.false_expr)
53            && self.case_expr.eq(&other.case_expr)
54    }
55}
56
57impl std::fmt::Display for IfExpr {
58    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59        write!(
60            f,
61            "If [if: {}, true_expr: {}, false_expr: {}]",
62            self.if_expr, self.true_expr, self.false_expr
63        )
64    }
65}
66
67impl IfExpr {
68    /// Create a new IF expression
69    pub fn new(
70        if_expr: Arc<dyn PhysicalExpr>,
71        true_expr: Arc<dyn PhysicalExpr>,
72        false_expr: Arc<dyn PhysicalExpr>,
73    ) -> Self {
74        Self {
75            if_expr: Arc::clone(&if_expr),
76            true_expr: Arc::clone(&true_expr),
77            false_expr: Arc::clone(&false_expr),
78            case_expr: Arc::new(
79                CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
80            ),
81        }
82    }
83}
84
85impl PhysicalExpr for IfExpr {
86    /// Return a reference to Any that can be used for down-casting
87    fn as_any(&self) -> &dyn Any {
88        self
89    }
90
91    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
92        unimplemented!()
93    }
94
95    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
96        let data_type = self.true_expr.data_type(input_schema)?;
97        Ok(data_type)
98    }
99
100    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
101        if self.true_expr.nullable(_input_schema)? || self.true_expr.nullable(_input_schema)? {
102            Ok(true)
103        } else {
104            Ok(false)
105        }
106    }
107
108    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
109        self.case_expr.evaluate(batch)
110    }
111
112    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
113        vec![&self.if_expr, &self.true_expr, &self.false_expr]
114    }
115
116    fn with_new_children(
117        self: Arc<Self>,
118        children: Vec<Arc<dyn PhysicalExpr>>,
119    ) -> Result<Arc<dyn PhysicalExpr>> {
120        Ok(Arc::new(IfExpr::new(
121            Arc::clone(&children[0]),
122            Arc::clone(&children[1]),
123            Arc::clone(&children[2]),
124        )))
125    }
126}
127
#[cfg(test)]
mod tests {
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use arrow::datatypes::*;
    use datafusion::common::cast::as_int32_array;
    use datafusion::logical_expr::Operator;
    use datafusion::physical_expr::expressions::{binary, col, lit};

    use super::*;

    /// Wrap the three operands in an `IfExpr` and return it as a generic physical expression.
    fn if_fn(
        if_expr: Arc<dyn PhysicalExpr>,
        true_expr: Arc<dyn PhysicalExpr>,
        false_expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let wrapped: Arc<dyn PhysicalExpr> =
            Arc::new(IfExpr::new(if_expr, true_expr, false_expr));
        Ok(wrapped)
    }

    /// Evaluate `expr` over `batch` and materialize the result as an array with one entry
    /// per row.
    fn evaluate_to_array(expr: &Arc<dyn PhysicalExpr>, batch: &RecordBatch) -> Result<ArrayRef> {
        expr.evaluate(batch)?.into_array(batch.num_rows())
    }

    #[test]
    fn test_if_1() -> Result<()> {
        let field = Field::new("a", DataType::Utf8, true);
        let column: ArrayRef = Arc::new(StringArray::from(vec![
            Some("foo"),
            Some("baz"),
            None,
            Some("bar"),
        ]));
        let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![column])?;
        let schema_ref = batch.schema();

        // IF(a = 'foo', 123, 999); the NULL row falls through to the false branch
        let condition = binary(col("a", &schema_ref)?, Operator::Eq, lit("foo"), &schema_ref)?;
        let expr = if_fn(condition, lit(123i32), lit(999i32))?;

        let actual = evaluate_to_array(&expr, &batch)?;
        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]);
        assert_eq!(&expected, as_int32_array(&actual)?);

        Ok(())
    }

    #[test]
    fn test_if_2() -> Result<()> {
        let field = Field::new("a", DataType::Int32, true);
        let column: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(0), None, Some(5)]));
        let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![column])?;
        let schema_ref = batch.schema();

        // IF(a >= 1, 123, 999)
        let condition = binary(col("a", &schema_ref)?, Operator::GtEq, lit(1), &schema_ref)?;
        let expr = if_fn(condition, lit(123i32), lit(999i32))?;

        let actual = evaluate_to_array(&expr, &batch)?;
        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]);
        assert_eq!(&expected, as_int32_array(&actual)?);

        Ok(())
    }

    #[test]
    fn test_if_children() {
        let expr = if_fn(lit(true), lit(123i32), lit(999i32)).unwrap();
        // children() must list exactly [if, true, false], in that order
        let rendered: Vec<String> = expr
            .children()
            .into_iter()
            .map(|child| child.to_string())
            .collect();
        assert_eq!(rendered, ["true", "123", "999"]);
    }
}
210}