datafusion_comet_spark_expr/conditional_funcs/
if_expr.rs1use arrow::{
19 datatypes::{DataType, Schema},
20 record_batch::RecordBatch,
21};
22use datafusion::logical_expr::ColumnarValue;
23use datafusion_common::Result;
24use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr};
25use std::hash::Hash;
26use std::{any::Any, sync::Arc};
27
28#[derive(Debug, Eq)]
31pub struct IfExpr {
32 if_expr: Arc<dyn PhysicalExpr>,
33 true_expr: Arc<dyn PhysicalExpr>,
34 false_expr: Arc<dyn PhysicalExpr>,
35 case_expr: Arc<CaseExpr>,
37}
38
39impl Hash for IfExpr {
40 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
41 self.if_expr.hash(state);
42 self.true_expr.hash(state);
43 self.false_expr.hash(state);
44 self.case_expr.hash(state);
45 }
46}
47impl PartialEq for IfExpr {
48 fn eq(&self, other: &Self) -> bool {
49 self.if_expr.eq(&other.if_expr)
50 && self.true_expr.eq(&other.true_expr)
51 && self.false_expr.eq(&other.false_expr)
52 && self.case_expr.eq(&other.case_expr)
53 }
54}
55
56impl std::fmt::Display for IfExpr {
57 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
58 write!(
59 f,
60 "If [if: {}, true_expr: {}, false_expr: {}]",
61 self.if_expr, self.true_expr, self.false_expr
62 )
63 }
64}
65
66impl IfExpr {
67 pub fn new(
69 if_expr: Arc<dyn PhysicalExpr>,
70 true_expr: Arc<dyn PhysicalExpr>,
71 false_expr: Arc<dyn PhysicalExpr>,
72 ) -> Self {
73 Self {
74 if_expr: Arc::clone(&if_expr),
75 true_expr: Arc::clone(&true_expr),
76 false_expr: Arc::clone(&false_expr),
77 case_expr: Arc::new(
78 CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
79 ),
80 }
81 }
82}
83
84impl PhysicalExpr for IfExpr {
85 fn as_any(&self) -> &dyn Any {
87 self
88 }
89
90 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
91 let data_type = self.true_expr.data_type(input_schema)?;
92 Ok(data_type)
93 }
94
95 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
96 if self.true_expr.nullable(_input_schema)? || self.true_expr.nullable(_input_schema)? {
97 Ok(true)
98 } else {
99 Ok(false)
100 }
101 }
102
103 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
104 self.case_expr.evaluate(batch)
105 }
106
107 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
108 vec![&self.if_expr, &self.true_expr, &self.false_expr]
109 }
110
111 fn with_new_children(
112 self: Arc<Self>,
113 children: Vec<Arc<dyn PhysicalExpr>>,
114 ) -> Result<Arc<dyn PhysicalExpr>> {
115 Ok(Arc::new(IfExpr::new(
116 Arc::clone(&children[0]),
117 Arc::clone(&children[1]),
118 Arc::clone(&children[2]),
119 )))
120 }
121}
122
#[cfg(test)]
mod tests {
    use arrow::{array::StringArray, datatypes::*};
    use arrow_array::Int32Array;
    use datafusion::logical_expr::Operator;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{binary, col, lit};

    use super::*;

    /// Wraps [`IfExpr::new`] and returns the result as a trait object, the
    /// way a planner would hold it.
    fn if_fn(
        if_expr: Arc<dyn PhysicalExpr>,
        true_expr: Arc<dyn PhysicalExpr>,
        false_expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(IfExpr::new(if_expr, true_expr, false_expr));
        Ok(expr)
    }

    /// Evaluates `expr` over `batch` and returns the output as an owned
    /// `Int32Array` with one entry per input row.
    fn eval_to_int32(expr: &Arc<dyn PhysicalExpr>, batch: &RecordBatch) -> Result<Int32Array> {
        let column = expr.evaluate(batch)?.into_array(batch.num_rows())?;
        let ints = as_int32_array(&column)?;
        Ok(ints.clone())
    }

    /// `if (a = 'foo') 123 else 999` over a nullable string column: only the
    /// matching row takes the true branch; the null row falls through to the
    /// false branch.
    #[test]
    fn test_if_1() -> Result<()> {
        let column_a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column_a)])?;
        let schema_ref = batch.schema();

        let condition = binary(
            col("a", &schema_ref)?,
            Operator::Eq,
            lit("foo"),
            &schema_ref,
        )?;
        let expr = if_fn(condition, lit(123i32), lit(999i32))?;

        let actual = eval_to_int32(&expr, &batch)?;
        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]);
        assert_eq!(expected, actual);

        Ok(())
    }

    /// `if (a >= 1) 123 else 999` over a nullable int column: rows with
    /// `a >= 1` take the true branch; the zero row and the null row take the
    /// false branch.
    #[test]
    fn test_if_2() -> Result<()> {
        let column_a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column_a)])?;
        let schema_ref = batch.schema();

        let condition = binary(col("a", &schema_ref)?, Operator::GtEq, lit(1), &schema_ref)?;
        let expr = if_fn(condition, lit(123i32), lit(999i32))?;

        let actual = eval_to_int32(&expr, &batch)?;
        let expected = Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]);
        assert_eq!(expected, actual);

        Ok(())
    }

    /// `children()` exposes exactly the three operands, in
    /// `[if, true, false]` order.
    #[test]
    fn test_if_children() {
        let expr = if_fn(lit(true), lit(123i32), lit(999i32)).unwrap();

        let rendered: Vec<String> = expr
            .children()
            .into_iter()
            .map(|child| child.to_string())
            .collect();

        assert_eq!(rendered.len(), 3);
        assert_eq!(rendered[0], "true");
        assert_eq!(rendered[1], "123");
        assert_eq!(rendered[2], "999");
    }
}