1use arrow::datatypes::Schema;
21use datafusion_common::{Result, tree_node::TreeNode};
22use std::sync::Arc;
23
24use crate::{PhysicalExpr, simplifier::not::simplify_not_expr};
25
26pub mod const_evaluator;
27pub mod not;
28pub mod unwrap_cast;
29
/// Upper bound on the number of whole-tree rewrite passes performed by
/// `PhysicalExprSimplifier::simplify` before it stops iterating and returns
/// the expression produced by the last pass.
const MAX_LOOP_COUNT: usize = 5;
31
/// Simplifies physical expressions against a fixed [`Schema`].
///
/// The simplifier only borrows the schema; `simplify` hands it to the
/// individual rewrite steps (NOT simplification and cast unwrapping) and, in
/// test builds, uses it to check that rewrites preserve each node's data type.
pub struct PhysicalExprSimplifier<'a> {
    /// Schema the expressions being simplified are resolved against.
    schema: &'a Schema,
}
40
41impl<'a> PhysicalExprSimplifier<'a> {
42 pub fn new(schema: &'a Schema) -> Self {
44 Self { schema }
45 }
46
47 pub fn simplify(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
49 let mut current_expr = expr;
50 let mut count = 0;
51 let schema = self.schema;
52
53 while count < MAX_LOOP_COUNT {
54 count += 1;
55 let result = current_expr.transform(|node| {
56 #[cfg(test)]
57 let original_type = node.data_type(schema).unwrap();
58
59 let rewritten = simplify_not_expr(&node, schema)?
62 .transform_data(|node| {
63 unwrap_cast::unwrap_cast_in_comparison(node, schema)
64 })?
65 .transform_data(|node| const_evaluator::simplify_const_expr(&node))?;
66
67 #[cfg(test)]
68 assert_eq!(
69 rewritten.data.data_type(schema).unwrap(),
70 original_type,
71 "Simplified expression should have the same data type as the original"
72 );
73
74 Ok(rewritten)
75 })?;
76
77 if !result.transformed {
78 return Ok(result.data);
79 }
80 current_expr = result.data;
81 }
82 Ok(current_expr)
83 }
84}
85
86#[cfg(test)]
87mod tests {
88 use super::*;
89 use crate::expressions::{
90 BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, col, in_list, lit,
91 };
92 use arrow::datatypes::{DataType, Field, Schema};
93 use datafusion_common::ScalarValue;
94 use datafusion_expr::Operator;
95
96 fn test_schema() -> Schema {
97 Schema::new(vec![
98 Field::new("c1", DataType::Int32, false),
99 Field::new("c2", DataType::Int64, false),
100 Field::new("c3", DataType::Utf8, false),
101 ])
102 }
103
104 fn not_test_schema() -> Schema {
105 Schema::new(vec![
106 Field::new("a", DataType::Boolean, false),
107 Field::new("b", DataType::Boolean, false),
108 Field::new("c", DataType::Int32, false),
109 ])
110 }
111
112 fn as_literal(expr: &Arc<dyn PhysicalExpr>) -> &Literal {
114 expr.as_any()
115 .downcast_ref::<Literal>()
116 .unwrap_or_else(|| panic!("Expected Literal, got: {expr}"))
117 }
118
119 fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
121 expr.as_any()
122 .downcast_ref::<BinaryExpr>()
123 .unwrap_or_else(|| panic!("Expected BinaryExpr, got: {expr}"))
124 }
125
126 fn assert_not_simplify(
128 simplifier: &PhysicalExprSimplifier,
129 input: Arc<dyn PhysicalExpr>,
130 expected: Arc<dyn PhysicalExpr>,
131 ) {
132 let result = simplifier.simplify(Arc::clone(&input)).unwrap();
133 assert_eq!(
134 &result, &expected,
135 "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}"
136 );
137 }
138
139 #[test]
140 fn test_simplify() {
141 let schema = test_schema();
142 let simplifier = PhysicalExprSimplifier::new(&schema);
143
144 let column_expr = col("c2", &schema).unwrap();
146 let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int32, None));
147 let literal_expr = lit(ScalarValue::Int32(Some(99)));
148 let binary_expr =
149 Arc::new(BinaryExpr::new(cast_expr, Operator::NotEq, literal_expr));
150
151 let optimized = simplifier.simplify(binary_expr).unwrap();
153
154 let optimized_binary = as_binary(&optimized);
155
156 let left_expr = optimized_binary.left();
158 assert!(
159 left_expr.as_any().downcast_ref::<CastExpr>().is_none()
160 && left_expr.as_any().downcast_ref::<TryCastExpr>().is_none()
161 );
162 let right_literal = as_literal(optimized_binary.right());
163 assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99)));
164 }
165
166 #[test]
167 fn test_nested_expression_simplification() {
168 let schema = test_schema();
169 let simplifier = PhysicalExprSimplifier::new(&schema);
170
171 let c1_expr = col("c1", &schema).unwrap();
173 let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None));
174 let c1_literal = lit(ScalarValue::Int64(Some(5)));
175 let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal));
176
177 let c2_expr = col("c2", &schema).unwrap();
178 let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None));
179 let c2_literal = lit(ScalarValue::Int32(Some(10)));
180 let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::LtEq, c2_literal));
181
182 let or_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::Or, c2_binary));
183
184 let optimized = simplifier.simplify(or_expr).unwrap();
186
187 let or_binary = as_binary(&optimized);
188
189 let left_binary = as_binary(or_binary.left());
191 let left_left_expr = left_binary.left();
192 assert!(
193 left_left_expr.as_any().downcast_ref::<CastExpr>().is_none()
194 && left_left_expr
195 .as_any()
196 .downcast_ref::<TryCastExpr>()
197 .is_none()
198 );
199 let left_literal = as_literal(left_binary.right());
200 assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5)));
201
202 let right_binary = as_binary(or_binary.right());
204 let right_left_expr = right_binary.left();
205 assert!(
206 right_left_expr
207 .as_any()
208 .downcast_ref::<CastExpr>()
209 .is_none()
210 && right_left_expr
211 .as_any()
212 .downcast_ref::<TryCastExpr>()
213 .is_none()
214 );
215 let right_literal = as_literal(right_binary.right());
216 assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10)));
217 }
218
219 #[test]
220 fn test_double_negation_elimination() -> Result<()> {
221 let schema = not_test_schema();
222 let simplifier = PhysicalExprSimplifier::new(&schema);
223
224 let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
226 col("c", &schema)?,
227 Operator::Gt,
228 lit(ScalarValue::Int32(Some(5))),
229 ));
230 let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr)));
231 let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(inner_not));
232
233 let expected = inner_expr;
234 assert_not_simplify(&simplifier, double_not, expected);
235 Ok(())
236 }
237
238 #[test]
239 fn test_not_literal() -> Result<()> {
240 let schema = not_test_schema();
241 let simplifier = PhysicalExprSimplifier::new(&schema);
242
243 let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true)))));
245 let expected = lit(ScalarValue::Boolean(Some(false)));
246 assert_not_simplify(&simplifier, not_true, expected);
247
248 let not_false = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false)))));
250 let expected = lit(ScalarValue::Boolean(Some(true)));
251 assert_not_simplify(&simplifier, not_false, expected);
252
253 Ok(())
254 }
255
256 #[test]
257 fn test_negate_comparison() -> Result<()> {
258 let schema = not_test_schema();
259 let simplifier = PhysicalExprSimplifier::new(&schema);
260
261 let not_eq = Arc::new(NotExpr::new(Arc::new(BinaryExpr::new(
263 col("c", &schema)?,
264 Operator::Eq,
265 lit(ScalarValue::Int32(Some(5))),
266 ))));
267 let expected = Arc::new(BinaryExpr::new(
268 col("c", &schema)?,
269 Operator::NotEq,
270 lit(ScalarValue::Int32(Some(5))),
271 ));
272 assert_not_simplify(&simplifier, not_eq, expected);
273
274 Ok(())
275 }
276
277 #[test]
278 fn test_demorgans_law_and() -> Result<()> {
279 let schema = not_test_schema();
280 let simplifier = PhysicalExprSimplifier::new(&schema);
281
282 let and_expr = Arc::new(BinaryExpr::new(
284 col("a", &schema)?,
285 Operator::And,
286 col("b", &schema)?,
287 ));
288 let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
289
290 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
291 Arc::new(NotExpr::new(col("a", &schema)?)),
292 Operator::Or,
293 Arc::new(NotExpr::new(col("b", &schema)?)),
294 ));
295 assert_not_simplify(&simplifier, not_and, expected);
296
297 Ok(())
298 }
299
300 #[test]
301 fn test_demorgans_law_or() -> Result<()> {
302 let schema = not_test_schema();
303 let simplifier = PhysicalExprSimplifier::new(&schema);
304
305 let or_expr = Arc::new(BinaryExpr::new(
307 col("a", &schema)?,
308 Operator::Or,
309 col("b", &schema)?,
310 ));
311 let not_or: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(or_expr));
312
313 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
314 Arc::new(NotExpr::new(col("a", &schema)?)),
315 Operator::And,
316 Arc::new(NotExpr::new(col("b", &schema)?)),
317 ));
318 assert_not_simplify(&simplifier, not_or, expected);
319
320 Ok(())
321 }
322
323 #[test]
324 fn test_demorgans_with_comparison_simplification() -> Result<()> {
325 let schema = not_test_schema();
326 let simplifier = PhysicalExprSimplifier::new(&schema);
327
328 let eq1 = Arc::new(BinaryExpr::new(
330 col("c", &schema)?,
331 Operator::Eq,
332 lit(ScalarValue::Int32(Some(1))),
333 ));
334 let eq2 = Arc::new(BinaryExpr::new(
335 col("c", &schema)?,
336 Operator::Eq,
337 lit(ScalarValue::Int32(Some(2))),
338 ));
339 let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2));
340 let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
341
342 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
343 Arc::new(BinaryExpr::new(
344 col("c", &schema)?,
345 Operator::NotEq,
346 lit(ScalarValue::Int32(Some(1))),
347 )),
348 Operator::Or,
349 Arc::new(BinaryExpr::new(
350 col("c", &schema)?,
351 Operator::NotEq,
352 lit(ScalarValue::Int32(Some(2))),
353 )),
354 ));
355 assert_not_simplify(&simplifier, not_and, expected);
356
357 Ok(())
358 }
359
360 #[test]
361 fn test_not_of_not_and_not() -> Result<()> {
362 let schema = not_test_schema();
363 let simplifier = PhysicalExprSimplifier::new(&schema);
364
365 let not_a = Arc::new(NotExpr::new(col("a", &schema)?));
367 let not_b = Arc::new(NotExpr::new(col("b", &schema)?));
368 let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b));
369 let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
370
371 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
372 col("a", &schema)?,
373 Operator::Or,
374 col("b", &schema)?,
375 ));
376 assert_not_simplify(&simplifier, not_and, expected);
377
378 Ok(())
379 }
380
381 #[test]
382 fn test_not_in_list() -> Result<()> {
383 let schema = not_test_schema();
384 let simplifier = PhysicalExprSimplifier::new(&schema);
385
386 let list = vec![
388 lit(ScalarValue::Int32(Some(1))),
389 lit(ScalarValue::Int32(Some(2))),
390 lit(ScalarValue::Int32(Some(3))),
391 ];
392 let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?;
393 let not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(in_list_expr));
394
395 let expected = in_list(col("c", &schema)?, list, &true, &schema)?;
396 assert_not_simplify(&simplifier, not_in, expected);
397
398 Ok(())
399 }
400
401 #[test]
402 fn test_not_not_in_list() -> Result<()> {
403 let schema = not_test_schema();
404 let simplifier = PhysicalExprSimplifier::new(&schema);
405
406 let list = vec![
408 lit(ScalarValue::Int32(Some(1))),
409 lit(ScalarValue::Int32(Some(2))),
410 lit(ScalarValue::Int32(Some(3))),
411 ];
412 let not_in_list_expr = in_list(col("c", &schema)?, list.clone(), &true, &schema)?;
413 let not_not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in_list_expr));
414
415 let expected = in_list(col("c", &schema)?, list, &false, &schema)?;
416 assert_not_simplify(&simplifier, not_not_in, expected);
417
418 Ok(())
419 }
420
421 #[test]
422 fn test_double_not_in_list() -> Result<()> {
423 let schema = not_test_schema();
424 let simplifier = PhysicalExprSimplifier::new(&schema);
425
426 let list = vec![
428 lit(ScalarValue::Int32(Some(1))),
429 lit(ScalarValue::Int32(Some(2))),
430 lit(ScalarValue::Int32(Some(3))),
431 ];
432 let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?;
433 let not_in = Arc::new(NotExpr::new(in_list_expr));
434 let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in));
435
436 let expected = in_list(col("c", &schema)?, list, &false, &schema)?;
437 assert_not_simplify(&simplifier, double_not, expected);
438
439 Ok(())
440 }
441
    // Wraps `c > 5` in 200 nested NOT expressions and checks that simplification
    // yields the bare comparison again.
    #[test]
    fn test_deeply_nested_not() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // Innermost predicate: c > 5
        let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::Gt,
            lit(ScalarValue::Int32(Some(5))),
        ));

        // Build NOT(NOT(...NOT(c > 5)...)) with 200 layers.
        let mut expr = Arc::clone(&inner_expr);
        for _ in 0..200 {
            expr = Arc::new(NotExpr::new(expr));
        }

        // 200 is even, so the expected result is the inner predicate itself.
        let expected = inner_expr;
        // Pass a clone so `expr` stays owned here for the teardown loop below.
        assert_not_simplify(&simplifier, Arc::clone(&expr), expected);

        // Peel the NOT chain one layer per iteration. A clone of the child is
        // held before `expr` is overwritten, so the child is still referenced
        // when the old outer node is released.
        // NOTE(review): presumably this exists to avoid a deeply recursive drop
        // of the 200-level chain overflowing the stack; confirm.
        while let Some(not_expr) = expr.as_any().downcast_ref::<NotExpr>() {
            let child = Arc::clone(not_expr.arg());

            expr = child;
        }

        Ok(())
    }
487
488 #[test]
489 fn test_simplify_literal_binary_expr() {
490 let schema = Schema::empty();
491 let simplifier = PhysicalExprSimplifier::new(&schema);
492
493 let expr: Arc<dyn PhysicalExpr> =
495 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
496 let result = simplifier.simplify(expr).unwrap();
497 let literal = as_literal(&result);
498 assert_eq!(literal.value(), &ScalarValue::Int32(Some(3)));
499 }
500
501 #[test]
502 fn test_simplify_literal_comparison() {
503 let schema = Schema::empty();
504 let simplifier = PhysicalExprSimplifier::new(&schema);
505
506 let expr: Arc<dyn PhysicalExpr> =
508 Arc::new(BinaryExpr::new(lit(5i32), Operator::Gt, lit(3i32)));
509 let result = simplifier.simplify(expr).unwrap();
510 let literal = as_literal(&result);
511 assert_eq!(literal.value(), &ScalarValue::Boolean(Some(true)));
512
513 let expr: Arc<dyn PhysicalExpr> =
515 Arc::new(BinaryExpr::new(lit(2i32), Operator::Gt, lit(3i32)));
516 let result = simplifier.simplify(expr).unwrap();
517 let literal = as_literal(&result);
518 assert_eq!(literal.value(), &ScalarValue::Boolean(Some(false)));
519 }
520
521 #[test]
522 fn test_simplify_nested_literal_expr() {
523 let schema = Schema::empty();
524 let simplifier = PhysicalExprSimplifier::new(&schema);
525
526 let inner: Arc<dyn PhysicalExpr> =
528 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
529 let expr: Arc<dyn PhysicalExpr> =
530 Arc::new(BinaryExpr::new(inner, Operator::Multiply, lit(3i32)));
531 let result = simplifier.simplify(expr).unwrap();
532 let literal = as_literal(&result);
533 assert_eq!(literal.value(), &ScalarValue::Int32(Some(9)));
534 }
535
536 #[test]
537 fn test_simplify_deeply_nested_literals() {
538 let schema = Schema::empty();
539 let simplifier = PhysicalExprSimplifier::new(&schema);
540
541 let left: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
543 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))),
544 Operator::Multiply,
545 lit(3i32),
546 ));
547 let right: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
548 Arc::new(BinaryExpr::new(lit(4i32), Operator::Minus, lit(1i32))),
549 Operator::Multiply,
550 lit(2i32),
551 ));
552 let expr: Arc<dyn PhysicalExpr> =
553 Arc::new(BinaryExpr::new(left, Operator::Plus, right));
554 let result = simplifier.simplify(expr).unwrap();
555 let literal = as_literal(&result);
556 assert_eq!(literal.value(), &ScalarValue::Int32(Some(15)));
557 }
558
559 #[test]
560 fn test_no_simplify_with_column() {
561 let schema = test_schema();
562 let simplifier = PhysicalExprSimplifier::new(&schema);
563
564 let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
566 col("c1", &schema).unwrap(),
567 Operator::Plus,
568 lit(2i32),
569 ));
570 let result = simplifier.simplify(expr).unwrap();
571 assert!(result.as_any().downcast_ref::<BinaryExpr>().is_some());
573 }
574
575 #[test]
576 fn test_partial_simplify_with_column() {
577 let schema = test_schema();
578 let simplifier = PhysicalExprSimplifier::new(&schema);
579
580 let literal_part: Arc<dyn PhysicalExpr> =
582 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
583 let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
584 literal_part,
585 Operator::Plus,
586 col("c1", &schema).unwrap(),
587 ));
588 let result = simplifier.simplify(expr).unwrap();
589
590 let binary = as_binary(&result);
592 let left_literal = as_literal(binary.left());
593 assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(3)));
594 }
595
596 #[test]
597 fn test_simplify_literal_string_concat() {
598 let schema = Schema::empty();
599 let simplifier = PhysicalExprSimplifier::new(&schema);
600
601 let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
603 lit("hello"),
604 Operator::StringConcat,
605 lit(" world"),
606 ));
607 let result = simplifier.simplify(expr).unwrap();
608 let literal = as_literal(&result);
609 assert_eq!(
610 literal.value(),
611 &ScalarValue::Utf8(Some("hello world".to_string()))
612 );
613 }
614}