datafusion_comet_spark_expr/conditional_funcs/
if_expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{
19    datatypes::{DataType, Schema},
20    record_batch::RecordBatch,
21};
22use datafusion::logical_expr::ColumnarValue;
23use datafusion_common::Result;
24use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr};
25use std::hash::Hash;
26use std::{any::Any, sync::Arc};
27
28/// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to
29/// `CASE WHEN a THEN b ELSE c END`.
30#[derive(Debug, Eq)]
31pub struct IfExpr {
32    if_expr: Arc<dyn PhysicalExpr>,
33    true_expr: Arc<dyn PhysicalExpr>,
34    false_expr: Arc<dyn PhysicalExpr>,
35    // we delegate to case_expr for evaluation
36    case_expr: Arc<CaseExpr>,
37}
38
39impl Hash for IfExpr {
40    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
41        self.if_expr.hash(state);
42        self.true_expr.hash(state);
43        self.false_expr.hash(state);
44        self.case_expr.hash(state);
45    }
46}
47impl PartialEq for IfExpr {
48    fn eq(&self, other: &Self) -> bool {
49        self.if_expr.eq(&other.if_expr)
50            && self.true_expr.eq(&other.true_expr)
51            && self.false_expr.eq(&other.false_expr)
52            && self.case_expr.eq(&other.case_expr)
53    }
54}
55
56impl std::fmt::Display for IfExpr {
57    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
58        write!(
59            f,
60            "If [if: {}, true_expr: {}, false_expr: {}]",
61            self.if_expr, self.true_expr, self.false_expr
62        )
63    }
64}
65
66impl IfExpr {
67    /// Create a new IF expression
68    pub fn new(
69        if_expr: Arc<dyn PhysicalExpr>,
70        true_expr: Arc<dyn PhysicalExpr>,
71        false_expr: Arc<dyn PhysicalExpr>,
72    ) -> Self {
73        Self {
74            if_expr: Arc::clone(&if_expr),
75            true_expr: Arc::clone(&true_expr),
76            false_expr: Arc::clone(&false_expr),
77            case_expr: Arc::new(
78                CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
79            ),
80        }
81    }
82}
83
84impl PhysicalExpr for IfExpr {
85    /// Return a reference to Any that can be used for down-casting
86    fn as_any(&self) -> &dyn Any {
87        self
88    }
89
90    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
91        let data_type = self.true_expr.data_type(input_schema)?;
92        Ok(data_type)
93    }
94
95    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
96        if self.true_expr.nullable(_input_schema)? || self.true_expr.nullable(_input_schema)? {
97            Ok(true)
98        } else {
99            Ok(false)
100        }
101    }
102
103    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
104        self.case_expr.evaluate(batch)
105    }
106
107    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
108        vec![&self.if_expr, &self.true_expr, &self.false_expr]
109    }
110
111    fn with_new_children(
112        self: Arc<Self>,
113        children: Vec<Arc<dyn PhysicalExpr>>,
114    ) -> Result<Arc<dyn PhysicalExpr>> {
115        Ok(Arc::new(IfExpr::new(
116            Arc::clone(&children[0]),
117            Arc::clone(&children[1]),
118            Arc::clone(&children[2]),
119        )))
120    }
121}
122
#[cfg(test)]
mod tests {
    use arrow::{array::StringArray, datatypes::*};
    use arrow_array::Int32Array;
    use datafusion::logical_expr::Operator;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{binary, col, lit};

    use super::*;

    /// Wrap the three operands in an `IfExpr` and return it as a trait object.
    fn if_fn(
        if_expr: Arc<dyn PhysicalExpr>,
        true_expr: Arc<dyn PhysicalExpr>,
        false_expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let built = IfExpr::new(if_expr, true_expr, false_expr);
        Ok(Arc::new(built))
    }

    /// String predicate: `if a = 'foo' 123 else 999`, including a null input row.
    #[test]
    fn test_if_1() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let column = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)])?;
        let schema_ref = batch.schema();

        let predicate = binary(
            col("a", &schema_ref)?,
            Operator::Eq,
            lit("foo"),
            &schema_ref,
        )?;
        let expr = if_fn(predicate, lit(123i32), lit(999i32))?;

        let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let actual = as_int32_array(&evaluated)?;

        // The null row does not satisfy the predicate, so it takes the else value.
        let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]);
        assert_eq!(expected, actual);

        Ok(())
    }

    /// Integer predicate: `if a >= 1 123 else 999`, including a null input row.
    #[test]
    fn test_if_2() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let column = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)])?;
        let schema_ref = batch.schema();

        let predicate = binary(col("a", &schema_ref)?, Operator::GtEq, lit(1), &schema_ref)?;
        let expr = if_fn(predicate, lit(123i32), lit(999i32))?;

        let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let actual = as_int32_array(&evaluated)?;

        let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]);
        assert_eq!(expected, actual);

        Ok(())
    }

    /// `children()` exposes the condition, true branch and false branch, in that order.
    #[test]
    fn test_if_children() {
        let expr = if_fn(lit(true), lit(123i32), lit(999i32)).unwrap();

        let rendered: Vec<String> = expr
            .children()
            .iter()
            .map(|child| child.to_string())
            .collect();

        // Comparing against a three-element array checks both the count and the order.
        assert_eq!(rendered, ["true", "123", "999"]);
    }
}