1use arrow::datatypes::Schema;
21use datafusion_common::{Result, tree_node::TreeNode};
22use std::sync::Arc;
23
24use crate::{
25 PhysicalExpr,
26 simplifier::{
27 const_evaluator::create_dummy_batch, unwrap_cast::unwrap_cast_in_comparison,
28 },
29};
30
31pub mod const_evaluator;
32pub mod not;
33pub mod unwrap_cast;
34
/// Upper bound on the number of whole-tree rewrite passes performed by
/// [`PhysicalExprSimplifier::simplify`]; once it is reached the expression is
/// returned as rewritten so far.
const MAX_LOOP_COUNT: usize = 5;
36
/// Rewrites physical expressions into simpler forms.
///
/// Each pass of [`PhysicalExprSimplifier::simplify`] chains the rewrites from
/// the `not`, `unwrap_cast` and `const_evaluator` submodules, in that order.
pub struct PhysicalExprSimplifier<'a> {
    /// Schema passed to the rewrite steps and used to compute expression data
    /// types.
    schema: &'a Schema,
}
45
46impl<'a> PhysicalExprSimplifier<'a> {
47 pub fn new(schema: &'a Schema) -> Self {
49 Self { schema }
50 }
51
52 pub fn simplify(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
54 let mut current_expr = expr;
55 let mut count = 0;
56 let schema = self.schema;
57
58 let batch = create_dummy_batch()?;
59
60 while count < MAX_LOOP_COUNT {
61 count += 1;
62 let result = current_expr.transform(|node| {
63 #[cfg(debug_assertions)]
64 let original_type = node.data_type(schema).unwrap();
65
66 #[expect(deprecated, reason = "`simplify_not_expr` is marked as deprecated until it's made private.")]
69 let rewritten = not::simplify_not_expr(node, schema)?
70 .transform_data(|node| unwrap_cast_in_comparison(node, schema))?
71 .transform_data(|node| {
72 const_evaluator::simplify_const_expr_immediate(node, batch)
73 })?;
74
75 #[cfg(debug_assertions)]
76 assert_eq!(
77 rewritten.data.data_type(schema).unwrap(),
78 original_type,
79 "Simplified expression should have the same data type as the original"
80 );
81
82 Ok(rewritten)
83 })?;
84
85 if !result.transformed {
86 return Ok(result.data);
87 }
88 current_expr = result.data;
89 }
90 Ok(current_expr)
91 }
92}
93
94#[cfg(test)]
95mod tests {
96 use super::*;
97 use crate::expressions::{
98 BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, col, in_list, lit,
99 };
100 use arrow::datatypes::{DataType, Field};
101 use datafusion_common::ScalarValue;
102 use datafusion_expr::Operator;
103
104 fn test_schema() -> Schema {
105 Schema::new(vec![
106 Field::new("c1", DataType::Int32, false),
107 Field::new("c2", DataType::Int64, false),
108 Field::new("c3", DataType::Utf8, false),
109 ])
110 }
111
112 fn not_test_schema() -> Schema {
113 Schema::new(vec![
114 Field::new("a", DataType::Boolean, false),
115 Field::new("b", DataType::Boolean, false),
116 Field::new("c", DataType::Int32, false),
117 ])
118 }
119
120 fn as_literal(expr: &Arc<dyn PhysicalExpr>) -> &Literal {
122 expr.downcast_ref::<Literal>()
123 .unwrap_or_else(|| panic!("Expected Literal, got: {expr}"))
124 }
125
126 fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
128 expr.downcast_ref::<BinaryExpr>()
129 .unwrap_or_else(|| panic!("Expected BinaryExpr, got: {expr}"))
130 }
131
132 fn assert_not_simplify(
134 simplifier: &PhysicalExprSimplifier,
135 input: Arc<dyn PhysicalExpr>,
136 expected: Arc<dyn PhysicalExpr>,
137 ) {
138 let result = simplifier.simplify(Arc::clone(&input)).unwrap();
139 assert_eq!(
140 &result, &expected,
141 "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}"
142 );
143 }
144
145 #[test]
146 fn test_simplify() {
147 let schema = test_schema();
148 let simplifier = PhysicalExprSimplifier::new(&schema);
149
150 let column_expr = col("c2", &schema).unwrap();
152 let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int32, None));
153 let literal_expr = lit(ScalarValue::Int32(Some(99)));
154 let binary_expr =
155 Arc::new(BinaryExpr::new(cast_expr, Operator::NotEq, literal_expr));
156
157 let optimized = simplifier.simplify(binary_expr).unwrap();
159
160 let optimized_binary = as_binary(&optimized);
161
162 let left_expr = optimized_binary.left();
164 assert!(
165 left_expr.downcast_ref::<CastExpr>().is_none()
166 && left_expr.downcast_ref::<TryCastExpr>().is_none()
167 );
168 let right_literal = as_literal(optimized_binary.right());
169 assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99)));
170 }
171
172 #[test]
173 fn test_nested_expression_simplification() {
174 let schema = test_schema();
175 let simplifier = PhysicalExprSimplifier::new(&schema);
176
177 let c1_expr = col("c1", &schema).unwrap();
179 let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None));
180 let c1_literal = lit(ScalarValue::Int64(Some(5)));
181 let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal));
182
183 let c2_expr = col("c2", &schema).unwrap();
184 let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None));
185 let c2_literal = lit(ScalarValue::Int32(Some(10)));
186 let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::LtEq, c2_literal));
187
188 let or_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::Or, c2_binary));
189
190 let optimized = simplifier.simplify(or_expr).unwrap();
192
193 let or_binary = as_binary(&optimized);
194
195 let left_binary = as_binary(or_binary.left());
197 let left_left_expr = left_binary.left();
198 assert!(
199 left_left_expr.downcast_ref::<CastExpr>().is_none()
200 && left_left_expr.downcast_ref::<TryCastExpr>().is_none()
201 );
202 let left_literal = as_literal(left_binary.right());
203 assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5)));
204
205 let right_binary = as_binary(or_binary.right());
207 let right_left_expr = right_binary.left();
208 assert!(
209 right_left_expr.downcast_ref::<CastExpr>().is_none()
210 && right_left_expr.downcast_ref::<TryCastExpr>().is_none()
211 );
212 let right_literal = as_literal(right_binary.right());
213 assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10)));
214 }
215
216 #[test]
217 fn test_double_negation_elimination() -> Result<()> {
218 let schema = not_test_schema();
219 let simplifier = PhysicalExprSimplifier::new(&schema);
220
221 let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
223 col("c", &schema)?,
224 Operator::Gt,
225 lit(ScalarValue::Int32(Some(5))),
226 ));
227 let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr)));
228 let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(inner_not));
229
230 let expected = inner_expr;
231 assert_not_simplify(&simplifier, double_not, expected);
232 Ok(())
233 }
234
235 #[test]
236 fn test_not_literal() -> Result<()> {
237 let schema = not_test_schema();
238 let simplifier = PhysicalExprSimplifier::new(&schema);
239
240 let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true)))));
242 let expected = lit(ScalarValue::Boolean(Some(false)));
243 assert_not_simplify(&simplifier, not_true, expected);
244
245 let not_false = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false)))));
247 let expected = lit(ScalarValue::Boolean(Some(true)));
248 assert_not_simplify(&simplifier, not_false, expected);
249
250 Ok(())
251 }
252
253 #[test]
254 fn test_negate_comparison() -> Result<()> {
255 let schema = not_test_schema();
256 let simplifier = PhysicalExprSimplifier::new(&schema);
257
258 let not_eq = Arc::new(NotExpr::new(Arc::new(BinaryExpr::new(
260 col("c", &schema)?,
261 Operator::Eq,
262 lit(ScalarValue::Int32(Some(5))),
263 ))));
264 let expected = Arc::new(BinaryExpr::new(
265 col("c", &schema)?,
266 Operator::NotEq,
267 lit(ScalarValue::Int32(Some(5))),
268 ));
269 assert_not_simplify(&simplifier, not_eq, expected);
270
271 Ok(())
272 }
273
274 #[test]
275 fn test_demorgans_law_and() -> Result<()> {
276 let schema = not_test_schema();
277 let simplifier = PhysicalExprSimplifier::new(&schema);
278
279 let and_expr = Arc::new(BinaryExpr::new(
281 col("a", &schema)?,
282 Operator::And,
283 col("b", &schema)?,
284 ));
285 let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
286
287 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
288 Arc::new(NotExpr::new(col("a", &schema)?)),
289 Operator::Or,
290 Arc::new(NotExpr::new(col("b", &schema)?)),
291 ));
292 assert_not_simplify(&simplifier, not_and, expected);
293
294 Ok(())
295 }
296
297 #[test]
298 fn test_demorgans_law_or() -> Result<()> {
299 let schema = not_test_schema();
300 let simplifier = PhysicalExprSimplifier::new(&schema);
301
302 let or_expr = Arc::new(BinaryExpr::new(
304 col("a", &schema)?,
305 Operator::Or,
306 col("b", &schema)?,
307 ));
308 let not_or: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(or_expr));
309
310 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
311 Arc::new(NotExpr::new(col("a", &schema)?)),
312 Operator::And,
313 Arc::new(NotExpr::new(col("b", &schema)?)),
314 ));
315 assert_not_simplify(&simplifier, not_or, expected);
316
317 Ok(())
318 }
319
320 #[test]
321 fn test_demorgans_with_comparison_simplification() -> Result<()> {
322 let schema = not_test_schema();
323 let simplifier = PhysicalExprSimplifier::new(&schema);
324
325 let eq1 = Arc::new(BinaryExpr::new(
327 col("c", &schema)?,
328 Operator::Eq,
329 lit(ScalarValue::Int32(Some(1))),
330 ));
331 let eq2 = Arc::new(BinaryExpr::new(
332 col("c", &schema)?,
333 Operator::Eq,
334 lit(ScalarValue::Int32(Some(2))),
335 ));
336 let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2));
337 let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
338
339 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
340 Arc::new(BinaryExpr::new(
341 col("c", &schema)?,
342 Operator::NotEq,
343 lit(ScalarValue::Int32(Some(1))),
344 )),
345 Operator::Or,
346 Arc::new(BinaryExpr::new(
347 col("c", &schema)?,
348 Operator::NotEq,
349 lit(ScalarValue::Int32(Some(2))),
350 )),
351 ));
352 assert_not_simplify(&simplifier, not_and, expected);
353
354 Ok(())
355 }
356
357 #[test]
358 fn test_not_of_not_and_not() -> Result<()> {
359 let schema = not_test_schema();
360 let simplifier = PhysicalExprSimplifier::new(&schema);
361
362 let not_a = Arc::new(NotExpr::new(col("a", &schema)?));
364 let not_b = Arc::new(NotExpr::new(col("b", &schema)?));
365 let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b));
366 let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
367
368 let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
369 col("a", &schema)?,
370 Operator::Or,
371 col("b", &schema)?,
372 ));
373 assert_not_simplify(&simplifier, not_and, expected);
374
375 Ok(())
376 }
377
378 #[test]
379 fn test_not_in_list() -> Result<()> {
380 let schema = not_test_schema();
381 let simplifier = PhysicalExprSimplifier::new(&schema);
382
383 let list = vec![
385 lit(ScalarValue::Int32(Some(1))),
386 lit(ScalarValue::Int32(Some(2))),
387 lit(ScalarValue::Int32(Some(3))),
388 ];
389 let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?;
390 let not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(in_list_expr));
391
392 let expected = in_list(col("c", &schema)?, list, &true, &schema)?;
393 assert_not_simplify(&simplifier, not_in, expected);
394
395 Ok(())
396 }
397
398 #[test]
399 fn test_not_not_in_list() -> Result<()> {
400 let schema = not_test_schema();
401 let simplifier = PhysicalExprSimplifier::new(&schema);
402
403 let list = vec![
405 lit(ScalarValue::Int32(Some(1))),
406 lit(ScalarValue::Int32(Some(2))),
407 lit(ScalarValue::Int32(Some(3))),
408 ];
409 let not_in_list_expr = in_list(col("c", &schema)?, list.clone(), &true, &schema)?;
410 let not_not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in_list_expr));
411
412 let expected = in_list(col("c", &schema)?, list, &false, &schema)?;
413 assert_not_simplify(&simplifier, not_not_in, expected);
414
415 Ok(())
416 }
417
418 #[test]
419 fn test_double_not_in_list() -> Result<()> {
420 let schema = not_test_schema();
421 let simplifier = PhysicalExprSimplifier::new(&schema);
422
423 let list = vec![
425 lit(ScalarValue::Int32(Some(1))),
426 lit(ScalarValue::Int32(Some(2))),
427 lit(ScalarValue::Int32(Some(3))),
428 ];
429 let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?;
430 let not_in = Arc::new(NotExpr::new(in_list_expr));
431 let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in));
432
433 let expected = in_list(col("c", &schema)?, list, &false, &schema)?;
434 assert_not_simplify(&simplifier, double_not, expected);
435
436 Ok(())
437 }
438
439 #[test]
440 fn test_deeply_nested_not() -> Result<()> {
441 let schema = not_test_schema();
442 let simplifier = PhysicalExprSimplifier::new(&schema);
443
444 let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
449 col("c", &schema)?,
450 Operator::Gt,
451 lit(ScalarValue::Int32(Some(5))),
452 ));
453
454 let mut expr = Arc::clone(&inner_expr);
455 for _ in 0..200 {
457 expr = Arc::new(NotExpr::new(expr));
458 }
459
460 let expected = inner_expr;
462 assert_not_simplify(&simplifier, Arc::clone(&expr), expected);
463
464 while let Some(not_expr) = expr.downcast_ref::<NotExpr>() {
469 let child = Arc::clone(not_expr.arg());
472
473 expr = child;
480 }
481
482 Ok(())
483 }
484
485 #[test]
486 fn test_simplify_literal_binary_expr() {
487 let schema = Schema::empty();
488 let simplifier = PhysicalExprSimplifier::new(&schema);
489
490 let expr: Arc<dyn PhysicalExpr> =
492 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
493 let result = simplifier.simplify(expr).unwrap();
494 let literal = as_literal(&result);
495 assert_eq!(literal.value(), &ScalarValue::Int32(Some(3)));
496 }
497
498 #[test]
499 fn test_simplify_literal_comparison() {
500 let schema = Schema::empty();
501 let simplifier = PhysicalExprSimplifier::new(&schema);
502
503 let expr: Arc<dyn PhysicalExpr> =
505 Arc::new(BinaryExpr::new(lit(5i32), Operator::Gt, lit(3i32)));
506 let result = simplifier.simplify(expr).unwrap();
507 let literal = as_literal(&result);
508 assert_eq!(literal.value(), &ScalarValue::Boolean(Some(true)));
509
510 let expr: Arc<dyn PhysicalExpr> =
512 Arc::new(BinaryExpr::new(lit(2i32), Operator::Gt, lit(3i32)));
513 let result = simplifier.simplify(expr).unwrap();
514 let literal = as_literal(&result);
515 assert_eq!(literal.value(), &ScalarValue::Boolean(Some(false)));
516 }
517
518 #[test]
519 fn test_simplify_nested_literal_expr() {
520 let schema = Schema::empty();
521 let simplifier = PhysicalExprSimplifier::new(&schema);
522
523 let inner: Arc<dyn PhysicalExpr> =
525 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
526 let expr: Arc<dyn PhysicalExpr> =
527 Arc::new(BinaryExpr::new(inner, Operator::Multiply, lit(3i32)));
528 let result = simplifier.simplify(expr).unwrap();
529 let literal = as_literal(&result);
530 assert_eq!(literal.value(), &ScalarValue::Int32(Some(9)));
531 }
532
533 #[test]
534 fn test_simplify_deeply_nested_literals() {
535 let schema = Schema::empty();
536 let simplifier = PhysicalExprSimplifier::new(&schema);
537
538 let left: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
540 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))),
541 Operator::Multiply,
542 lit(3i32),
543 ));
544 let right: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
545 Arc::new(BinaryExpr::new(lit(4i32), Operator::Minus, lit(1i32))),
546 Operator::Multiply,
547 lit(2i32),
548 ));
549 let expr: Arc<dyn PhysicalExpr> =
550 Arc::new(BinaryExpr::new(left, Operator::Plus, right));
551 let result = simplifier.simplify(expr).unwrap();
552 let literal = as_literal(&result);
553 assert_eq!(literal.value(), &ScalarValue::Int32(Some(15)));
554 }
555
556 #[test]
557 fn test_no_simplify_with_column() {
558 let schema = test_schema();
559 let simplifier = PhysicalExprSimplifier::new(&schema);
560
561 let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
563 col("c1", &schema).unwrap(),
564 Operator::Plus,
565 lit(2i32),
566 ));
567 let result = simplifier.simplify(expr).unwrap();
568 assert!(result.downcast_ref::<BinaryExpr>().is_some());
570 }
571
572 #[test]
573 fn test_partial_simplify_with_column() {
574 let schema = test_schema();
575 let simplifier = PhysicalExprSimplifier::new(&schema);
576
577 let literal_part: Arc<dyn PhysicalExpr> =
579 Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
580 let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
581 literal_part,
582 Operator::Plus,
583 col("c1", &schema).unwrap(),
584 ));
585 let result = simplifier.simplify(expr).unwrap();
586
587 let binary = as_binary(&result);
589 let left_literal = as_literal(binary.left());
590 assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(3)));
591 }
592
593 #[test]
603 fn test_no_simplify_opaque_leaf_expr() {
604 use arrow::array::ArrayRef;
605 use arrow::array::Int32Array;
606 use arrow::record_batch::RecordBatch;
607 use datafusion_expr_common::columnar_value::ColumnarValue;
608 use datafusion_physical_expr_common::physical_expr::PhysicalExpr as PhysicalExprTrait;
609 use std::fmt;
610
611 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
612 struct OpaqueLeaf;
613
614 impl fmt::Display for OpaqueLeaf {
615 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
616 write!(f, "OpaqueLeaf")
617 }
618 }
619
620 impl PhysicalExprTrait for OpaqueLeaf {
621 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
622 Ok(DataType::Int32)
623 }
624 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
625 Ok(true)
626 }
627 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
628 let arr: ArrayRef = Arc::new(Int32Array::from(vec![0; batch.num_rows()]));
633 Ok(ColumnarValue::Array(arr))
634 }
635 fn children(&self) -> Vec<&Arc<dyn PhysicalExprTrait>> {
636 vec![]
637 }
638 fn with_new_children(
639 self: Arc<Self>,
640 _children: Vec<Arc<dyn PhysicalExprTrait>>,
641 ) -> Result<Arc<dyn PhysicalExprTrait>> {
642 Ok(self)
643 }
644 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
645 write!(f, "OpaqueLeaf")
646 }
647 }
648
649 let schema = Schema::empty();
650 let simplifier = PhysicalExprSimplifier::new(&schema);
651
652 let opaque: Arc<dyn PhysicalExpr> = Arc::new(OpaqueLeaf);
653 let result = simplifier.simplify(Arc::clone(&opaque)).unwrap();
654
655 assert!(
656 result.downcast_ref::<Literal>().is_none(),
657 "opaque leaf must not be rewritten to a Literal, got: {result}"
658 );
659 assert_eq!(&result, &opaque);
660 }
661
662 #[test]
663 fn test_simplify_literal_string_concat() {
664 let schema = Schema::empty();
665 let simplifier = PhysicalExprSimplifier::new(&schema);
666
667 let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
669 lit("hello"),
670 Operator::StringConcat,
671 lit(" world"),
672 ));
673 let result = simplifier.simplify(expr).unwrap();
674 let literal = as_literal(&result);
675 assert_eq!(
676 literal.value(),
677 &ScalarValue::Utf8(Some("hello world".to_string()))
678 );
679 }
680}