datafusion_comet_spark_expr/conditional_funcs/
if_expr.rs1use arrow::{
19 datatypes::{DataType, Schema},
20 record_batch::RecordBatch,
21};
22use datafusion::common::Result;
23use datafusion::logical_expr::ColumnarValue;
24use datafusion::physical_expr::{expressions::CaseExpr, PhysicalExpr};
25use std::fmt::Formatter;
26use std::hash::Hash;
27use std::{any::Any, sync::Arc};
28
29#[derive(Debug, Eq)]
32pub struct IfExpr {
33 if_expr: Arc<dyn PhysicalExpr>,
34 true_expr: Arc<dyn PhysicalExpr>,
35 false_expr: Arc<dyn PhysicalExpr>,
36 case_expr: Arc<CaseExpr>,
38}
39
40impl Hash for IfExpr {
41 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42 self.if_expr.hash(state);
43 self.true_expr.hash(state);
44 self.false_expr.hash(state);
45 self.case_expr.hash(state);
46 }
47}
48impl PartialEq for IfExpr {
49 fn eq(&self, other: &Self) -> bool {
50 self.if_expr.eq(&other.if_expr)
51 && self.true_expr.eq(&other.true_expr)
52 && self.false_expr.eq(&other.false_expr)
53 && self.case_expr.eq(&other.case_expr)
54 }
55}
56
57impl std::fmt::Display for IfExpr {
58 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59 write!(
60 f,
61 "If [if: {}, true_expr: {}, false_expr: {}]",
62 self.if_expr, self.true_expr, self.false_expr
63 )
64 }
65}
66
67impl IfExpr {
68 pub fn new(
70 if_expr: Arc<dyn PhysicalExpr>,
71 true_expr: Arc<dyn PhysicalExpr>,
72 false_expr: Arc<dyn PhysicalExpr>,
73 ) -> Self {
74 Self {
75 if_expr: Arc::clone(&if_expr),
76 true_expr: Arc::clone(&true_expr),
77 false_expr: Arc::clone(&false_expr),
78 case_expr: Arc::new(
79 CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
80 ),
81 }
82 }
83}
84
85impl PhysicalExpr for IfExpr {
86 fn as_any(&self) -> &dyn Any {
88 self
89 }
90
91 fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
92 unimplemented!()
93 }
94
95 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
96 let data_type = self.true_expr.data_type(input_schema)?;
97 Ok(data_type)
98 }
99
100 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
101 if self.true_expr.nullable(_input_schema)? || self.true_expr.nullable(_input_schema)? {
102 Ok(true)
103 } else {
104 Ok(false)
105 }
106 }
107
108 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
109 self.case_expr.evaluate(batch)
110 }
111
112 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
113 vec![&self.if_expr, &self.true_expr, &self.false_expr]
114 }
115
116 fn with_new_children(
117 self: Arc<Self>,
118 children: Vec<Arc<dyn PhysicalExpr>>,
119 ) -> Result<Arc<dyn PhysicalExpr>> {
120 Ok(Arc::new(IfExpr::new(
121 Arc::clone(&children[0]),
122 Arc::clone(&children[1]),
123 Arc::clone(&children[2]),
124 )))
125 }
126}
127
#[cfg(test)]
mod tests {
    use arrow::array::Int32Array;
    use arrow::{array::StringArray, datatypes::*};
    use datafusion::common::cast::as_int32_array;
    use datafusion::logical_expr::Operator;
    use datafusion::physical_expr::expressions::{binary, col, lit};

    use super::*;

    /// Builds an `IfExpr` from its three operands and returns it as a
    /// type-erased physical expression.
    fn if_fn(
        if_expr: Arc<dyn PhysicalExpr>,
        true_expr: Arc<dyn PhysicalExpr>,
        false_expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(IfExpr::new(if_expr, true_expr, false_expr));
        Ok(expr)
    }

    /// `IF(a = 'foo', 123, 999)` over a nullable string column.
    #[test]
    fn test_if_1() -> Result<()> {
        let fields = vec![Field::new("a", DataType::Utf8, true)];
        let column = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
        let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;
        let schema = batch.schema();

        // Predicate: a = 'foo'
        let predicate = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
        let expr = if_fn(predicate, lit(123i32), lit(999i32))?;

        let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let actual = as_int32_array(&evaluated)?;

        // Only the first row matches; the NULL input row takes the false branch.
        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]);
        assert_eq!(&expected, actual);

        Ok(())
    }

    /// `IF(a >= 1, 123, 999)` over a nullable integer column.
    #[test]
    fn test_if_2() -> Result<()> {
        let fields = vec![Field::new("a", DataType::Int32, true)];
        let column = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
        let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;
        let schema = batch.schema();

        // Predicate: a >= 1
        let predicate = binary(col("a", &schema)?, Operator::GtEq, lit(1), &schema)?;
        let expr = if_fn(predicate, lit(123i32), lit(999i32))?;

        let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let actual = as_int32_array(&evaluated)?;

        // Rows holding 1 and 5 match; 0 and the NULL row take the false branch.
        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]);
        assert_eq!(&expected, actual);

        Ok(())
    }

    /// `children()` exposes condition, true branch and false branch, in that order.
    #[test]
    fn test_if_children() {
        let expr = if_fn(lit(true), lit(123i32), lit(999i32)).unwrap();

        let children = expr.children();
        assert_eq!(children.len(), 3);

        let rendered: Vec<String> = children.iter().map(|child| child.to_string()).collect();
        assert_eq!(rendered[0], "true");
        assert_eq!(rendered[1], "123");
        assert_eq!(rendered[2], "999");
    }
}