1use arrow::datatypes::Schema;
21use datafusion_common::{Result, tree_node::TreeNode};
22use std::sync::Arc;
23
24use crate::{
25 PhysicalExpr,
26 simplifier::{
27 const_evaluator::create_dummy_batch, unwrap_cast::unwrap_cast_in_comparison,
28 },
29};
30
31pub mod const_evaluator;
32pub mod not;
33pub mod unwrap_cast;
34
/// Maximum number of full rewrite passes `PhysicalExprSimplifier::simplify`
/// runs. If the last pass still reports a change, its result is returned as-is
/// rather than iterating further.
const MAX_LOOP_COUNT: usize = 5;
36
37pub struct PhysicalExprSimplifier<'a> {
43 schema: &'a Schema,
44}
45
46impl<'a> PhysicalExprSimplifier<'a> {
47 pub fn new(schema: &'a Schema) -> Self {
49 Self { schema }
50 }
51
52 pub fn simplify(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
54 let mut current_expr = expr;
55 let mut count = 0;
56 let schema = self.schema;
57
58 let batch = create_dummy_batch()?;
59
60 while count < MAX_LOOP_COUNT {
61 count += 1;
62 let result = current_expr.transform(|node| {
63 #[cfg(debug_assertions)]
64 let original_type = node.data_type(schema).unwrap();
65
66 #[expect(deprecated, reason = "`simplify_not_expr` is marked as deprecated until it's made private.")]
69 let rewritten = not::simplify_not_expr(node, schema)?
70 .transform_data(|node| unwrap_cast_in_comparison(node, schema))?
71 .transform_data(|node| {
72 const_evaluator::simplify_const_expr_immediate(node, &batch)
73 })?;
74
75 #[cfg(debug_assertions)]
76 assert_eq!(
77 rewritten.data.data_type(schema).unwrap(),
78 original_type,
79 "Simplified expression should have the same data type as the original"
80 );
81
82 Ok(rewritten)
83 })?;
84
85 if !result.transformed {
86 return Ok(result.data);
87 }
88 current_expr = result.data;
89 }
90 Ok(current_expr)
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{
        BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, col, in_list, lit,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Schema with `c1: Int32`, `c2: Int64` and `c3: Utf8`, used by the
    /// cast-unwrapping and constant-evaluation tests.
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ])
    }

    /// Schema with `a: Boolean`, `b: Boolean` and `c: Int32`, used by the
    /// NOT-simplification tests.
    fn not_test_schema() -> Schema {
        Schema::new(vec![
            Field::new("a", DataType::Boolean, false),
            Field::new("b", DataType::Boolean, false),
            Field::new("c", DataType::Int32, false),
        ])
    }

    /// Downcasts `expr` to a [`Literal`], panicking with its display form if it
    /// is some other expression.
    fn as_literal(expr: &Arc<dyn PhysicalExpr>) -> &Literal {
        expr.as_any()
            .downcast_ref::<Literal>()
            .unwrap_or_else(|| panic!("Expected Literal, got: {expr}"))
    }

    /// Downcasts `expr` to a [`BinaryExpr`], panicking with its display form if
    /// it is some other expression.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap_or_else(|| panic!("Expected BinaryExpr, got: {expr}"))
    }

    /// Asserts that `expr` is neither a [`CastExpr`] nor a [`TryCastExpr`],
    /// i.e. that the simplifier removed any cast that wrapped it.
    fn assert_no_cast(expr: &Arc<dyn PhysicalExpr>) {
        let any = expr.as_any();
        assert!(
            any.downcast_ref::<CastExpr>().is_none()
                && any.downcast_ref::<TryCastExpr>().is_none(),
            "Expected cast to be unwrapped, got: {expr}"
        );
    }

    /// Asserts that simplifying `input` yields exactly `expected`.
    fn assert_not_simplify(
        simplifier: &PhysicalExprSimplifier,
        input: Arc<dyn PhysicalExpr>,
        expected: Arc<dyn PhysicalExpr>,
    ) {
        let result = simplifier.simplify(Arc::clone(&input)).unwrap();
        assert_eq!(
            &result, &expected,
            "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}"
        );
    }

    #[test]
    fn test_simplify() {
        let schema = test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c2 AS Int32) != 99_i32: the cast is removed and the literal is
        // widened to the column's Int64 type.
        let column_expr = col("c2", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int32, None));
        let literal_expr = lit(ScalarValue::Int32(Some(99)));
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::NotEq, literal_expr));

        let optimized = simplifier.simplify(binary_expr).unwrap();

        let optimized_binary = as_binary(&optimized);
        assert_no_cast(optimized_binary.left());
        let right_literal = as_literal(optimized_binary.right());
        assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99)));
    }

    #[test]
    fn test_nested_expression_simplification() {
        let schema = test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c1 AS Int64) > 5_i64: literal should narrow to Int32.
        let c1_expr = col("c1", &schema).unwrap();
        let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None));
        let c1_literal = lit(ScalarValue::Int64(Some(5)));
        let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal));

        // CAST(c2 AS Int32) <= 10_i32: literal should widen to Int64.
        let c2_expr = col("c2", &schema).unwrap();
        let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None));
        let c2_literal = lit(ScalarValue::Int32(Some(10)));
        let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::LtEq, c2_literal));

        let or_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::Or, c2_binary));

        let optimized = simplifier.simplify(or_expr).unwrap();

        let or_binary = as_binary(&optimized);

        let left_binary = as_binary(or_binary.left());
        assert_no_cast(left_binary.left());
        let left_literal = as_literal(left_binary.right());
        assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5)));

        let right_binary = as_binary(or_binary.right());
        assert_no_cast(right_binary.left());
        let right_literal = as_literal(right_binary.right());
        assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10)));
    }

    #[test]
    fn test_double_negation_elimination() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(NOT(c > 5)) => c > 5
        let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::Gt,
            lit(ScalarValue::Int32(Some(5))),
        ));
        let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr)));
        let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(inner_not));

        let expected = inner_expr;
        assert_not_simplify(&simplifier, double_not, expected);
        Ok(())
    }

    #[test]
    fn test_not_literal() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(true) => false
        let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true)))));
        let expected = lit(ScalarValue::Boolean(Some(false)));
        assert_not_simplify(&simplifier, not_true, expected);

        // NOT(false) => true
        let not_false = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false)))));
        let expected = lit(ScalarValue::Boolean(Some(true)));
        assert_not_simplify(&simplifier, not_false, expected);

        Ok(())
    }

    #[test]
    fn test_negate_comparison() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(c = 5) => c != 5
        let not_eq = Arc::new(NotExpr::new(Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::Eq,
            lit(ScalarValue::Int32(Some(5))),
        ))));
        let expected = Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::NotEq,
            lit(ScalarValue::Int32(Some(5))),
        ));
        assert_not_simplify(&simplifier, not_eq, expected);

        Ok(())
    }

    #[test]
    fn test_demorgans_law_and() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(a AND b) => NOT a OR NOT b
        let and_expr = Arc::new(BinaryExpr::new(
            col("a", &schema)?,
            Operator::And,
            col("b", &schema)?,
        ));
        let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));

        let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(NotExpr::new(col("a", &schema)?)),
            Operator::Or,
            Arc::new(NotExpr::new(col("b", &schema)?)),
        ));
        assert_not_simplify(&simplifier, not_and, expected);

        Ok(())
    }

    #[test]
    fn test_demorgans_law_or() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(a OR b) => NOT a AND NOT b
        let or_expr = Arc::new(BinaryExpr::new(
            col("a", &schema)?,
            Operator::Or,
            col("b", &schema)?,
        ));
        let not_or: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(or_expr));

        let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(NotExpr::new(col("a", &schema)?)),
            Operator::And,
            Arc::new(NotExpr::new(col("b", &schema)?)),
        ));
        assert_not_simplify(&simplifier, not_or, expected);

        Ok(())
    }

    #[test]
    fn test_demorgans_with_comparison_simplification() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(c = 1 AND c = 2) => c != 1 OR c != 2
        let eq1 = Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::Eq,
            lit(ScalarValue::Int32(Some(1))),
        ));
        let eq2 = Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::Eq,
            lit(ScalarValue::Int32(Some(2))),
        ));
        let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2));
        let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));

        let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                col("c", &schema)?,
                Operator::NotEq,
                lit(ScalarValue::Int32(Some(1))),
            )),
            Operator::Or,
            Arc::new(BinaryExpr::new(
                col("c", &schema)?,
                Operator::NotEq,
                lit(ScalarValue::Int32(Some(2))),
            )),
        ));
        assert_not_simplify(&simplifier, not_and, expected);

        Ok(())
    }

    #[test]
    fn test_not_of_not_and_not() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(NOT a AND NOT b) => a OR b
        let not_a = Arc::new(NotExpr::new(col("a", &schema)?));
        let not_b = Arc::new(NotExpr::new(col("b", &schema)?));
        let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b));
        let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));

        let expected: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            col("a", &schema)?,
            Operator::Or,
            col("b", &schema)?,
        ));
        assert_not_simplify(&simplifier, not_and, expected);

        Ok(())
    }

    #[test]
    fn test_not_in_list() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(c IN (1, 2, 3)) => c NOT IN (1, 2, 3)
        let list = vec![
            lit(ScalarValue::Int32(Some(1))),
            lit(ScalarValue::Int32(Some(2))),
            lit(ScalarValue::Int32(Some(3))),
        ];
        let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?;
        let not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(in_list_expr));

        let expected = in_list(col("c", &schema)?, list, &true, &schema)?;
        assert_not_simplify(&simplifier, not_in, expected);

        Ok(())
    }

    #[test]
    fn test_not_not_in_list() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(c NOT IN (1, 2, 3)) => c IN (1, 2, 3)
        let list = vec![
            lit(ScalarValue::Int32(Some(1))),
            lit(ScalarValue::Int32(Some(2))),
            lit(ScalarValue::Int32(Some(3))),
        ];
        let not_in_list_expr = in_list(col("c", &schema)?, list.clone(), &true, &schema)?;
        let not_not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in_list_expr));

        let expected = in_list(col("c", &schema)?, list, &false, &schema)?;
        assert_not_simplify(&simplifier, not_not_in, expected);

        Ok(())
    }

    #[test]
    fn test_double_not_in_list() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // NOT(NOT(c IN (1, 2, 3))) => c IN (1, 2, 3)
        let list = vec![
            lit(ScalarValue::Int32(Some(1))),
            lit(ScalarValue::Int32(Some(2))),
            lit(ScalarValue::Int32(Some(3))),
        ];
        let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?;
        let not_in = Arc::new(NotExpr::new(in_list_expr));
        let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in));

        let expected = in_list(col("c", &schema)?, list, &false, &schema)?;
        assert_not_simplify(&simplifier, double_not, expected);

        Ok(())
    }

    #[test]
    fn test_deeply_nested_not() -> Result<()> {
        let schema = not_test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            col("c", &schema)?,
            Operator::Gt,
            lit(ScalarValue::Int32(Some(5))),
        ));

        // Wrap the comparison in 200 NOTs to exercise a deep expression tree.
        let mut expr = Arc::clone(&inner_expr);
        for _ in 0..200 {
            expr = Arc::new(NotExpr::new(expr));
        }

        // An even number of NOTs cancels out to the original comparison.
        let expected = inner_expr;
        assert_not_simplify(&simplifier, Arc::clone(&expr), expected);

        // Peel the NOT chain off one level at a time before `expr` goes out of
        // scope. Dropping the whole chain at once would release each nested
        // `NotExpr` from inside its parent's drop, recursing once per level.
        while let Some(not_expr) = expr.as_any().downcast_ref::<NotExpr>() {
            let child = Arc::clone(not_expr.arg());
            // Reassigning drops the current node only; `child` keeps the rest alive.
            expr = child;
        }

        Ok(())
    }

    #[test]
    fn test_simplify_literal_binary_expr() {
        let schema = Schema::empty();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // 1 + 2 => 3
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
        let result = simplifier.simplify(expr).unwrap();
        let literal = as_literal(&result);
        assert_eq!(literal.value(), &ScalarValue::Int32(Some(3)));
    }

    #[test]
    fn test_simplify_literal_comparison() {
        let schema = Schema::empty();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // 5 > 3 => true
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(lit(5i32), Operator::Gt, lit(3i32)));
        let result = simplifier.simplify(expr).unwrap();
        let literal = as_literal(&result);
        assert_eq!(literal.value(), &ScalarValue::Boolean(Some(true)));

        // 2 > 3 => false
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(lit(2i32), Operator::Gt, lit(3i32)));
        let result = simplifier.simplify(expr).unwrap();
        let literal = as_literal(&result);
        assert_eq!(literal.value(), &ScalarValue::Boolean(Some(false)));
    }

    #[test]
    fn test_simplify_nested_literal_expr() {
        let schema = Schema::empty();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // (1 + 2) * 3 => 9
        let inner: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(inner, Operator::Multiply, lit(3i32)));
        let result = simplifier.simplify(expr).unwrap();
        let literal = as_literal(&result);
        assert_eq!(literal.value(), &ScalarValue::Int32(Some(9)));
    }

    #[test]
    fn test_simplify_deeply_nested_literals() {
        let schema = Schema::empty();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // ((1 + 2) * 3) + ((4 - 1) * 2) => 9 + 6 => 15
        let left: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))),
            Operator::Multiply,
            lit(3i32),
        ));
        let right: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(lit(4i32), Operator::Minus, lit(1i32))),
            Operator::Multiply,
            lit(2i32),
        ));
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(left, Operator::Plus, right));
        let result = simplifier.simplify(expr).unwrap();
        let literal = as_literal(&result);
        assert_eq!(literal.value(), &ScalarValue::Int32(Some(15)));
    }

    #[test]
    fn test_no_simplify_with_column() {
        let schema = test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // c1 + 2 references a column, so nothing can be folded.
        let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            col("c1", &schema).unwrap(),
            Operator::Plus,
            lit(2i32),
        ));
        let result = simplifier.simplify(Arc::clone(&expr)).unwrap();
        assert!(result.as_any().downcast_ref::<BinaryExpr>().is_some());
        // Beyond keeping the same node kind, the expression must be unchanged.
        assert_eq!(&result, &expr);
    }

    #[test]
    fn test_partial_simplify_with_column() {
        let schema = test_schema();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // (1 + 2) + c1 => 3 + c1
        let literal_part: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32)));
        let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            literal_part,
            Operator::Plus,
            col("c1", &schema).unwrap(),
        ));
        let result = simplifier.simplify(expr).unwrap();

        let binary = as_binary(&result);
        let left_literal = as_literal(binary.left());
        assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(3)));
        // The column operand cannot be folded and must be left untouched.
        assert_eq!(binary.right(), &col("c1", &schema).unwrap());
    }

    #[test]
    fn test_simplify_literal_string_concat() {
        let schema = Schema::empty();
        let simplifier = PhysicalExprSimplifier::new(&schema);

        // 'hello' || ' world' => 'hello world'
        let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            lit("hello"),
            Operator::StringConcat,
            lit(" world"),
        ));
        let result = simplifier.simplify(expr).unwrap();
        let literal = as_literal(&result);
        assert_eq!(
            literal.value(),
            &ScalarValue::Utf8(Some("hello world".to_string()))
        );
    }
}