1use crate::{Between, BinaryExpr, Expr, expr::InList, lit};
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue};
23use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval};
24use std::borrow::Cow;
25
26pub struct GuaranteeRewriter<'a> {
30 guarantees: HashMap<&'a Expr, &'a NullableInterval>,
31}
32
33impl<'a> GuaranteeRewriter<'a> {
34 pub fn new(
35 guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
36 ) -> Self {
37 Self {
38 guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
39 }
40 }
41}
42
43pub fn rewrite_with_guarantees<'a>(
61 expr: Expr,
62 guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
63) -> Result<Transformed<Expr>> {
64 let guarantees_map: HashMap<&Expr, &NullableInterval> =
65 guarantees.into_iter().map(|(k, v)| (k, v)).collect();
66 rewrite_with_guarantees_map(expr, &guarantees_map)
67}
68
69pub fn rewrite_with_guarantees_map<'a>(
78 expr: Expr,
79 guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>,
80) -> Result<Transformed<Expr>> {
81 if guarantees.is_empty() {
82 return Ok(Transformed::no(expr));
83 }
84
85 expr.transform_up(|e| rewrite_expr(e, guarantees))
86}
87
88impl TreeNodeRewriter for GuaranteeRewriter<'_> {
89 type Node = Expr;
90
91 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
92 if self.guarantees.is_empty() {
93 return Ok(Transformed::no(expr));
94 }
95
96 rewrite_expr(expr, &self.guarantees)
97 }
98}
99
100fn rewrite_expr(
101 expr: Expr,
102 guarantees: &HashMap<&Expr, &NullableInterval>,
103) -> Result<Transformed<Expr>> {
104 if let Some(interval) = guarantees.get(&expr)
106 && let Some(value) = interval.single_value()
107 {
108 return Ok(Transformed::yes(lit(value)));
109 }
110
111 let result = match expr {
112 Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) {
113 Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)),
114 Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)),
115 _ => Transformed::no(Expr::IsNull(inner)),
116 },
117 Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) {
118 Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)),
119 Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)),
120 _ => Transformed::no(Expr::IsNotNull(inner)),
121 },
122 Expr::Between(b) => rewrite_between(b, guarantees)?,
123 Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?,
124 Expr::InList(i) => rewrite_inlist(i, guarantees)?,
125 expr => Transformed::no(expr),
126 };
127 Ok(result)
128}
129
130fn rewrite_between(
131 between: Between,
132 guarantees: &HashMap<&Expr, &NullableInterval>,
133) -> Result<Transformed<Expr>> {
134 let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
135 guarantees.get(between.expr.as_ref()),
136 between.low.as_ref(),
137 between.high.as_ref(),
138 ) else {
139 return Ok(Transformed::no(Expr::Between(between)));
140 };
141
142 let low = ensure_typed_null(low, high)?;
144 let high = ensure_typed_null(high, &low)?;
145
146 let Ok(between_interval) = Interval::try_new(low, high) else {
147 return Ok(Transformed::no(Expr::Between(between)));
150 };
151
152 if between_interval.lower().is_null() && between_interval.upper().is_null() {
153 return Ok(Transformed::yes(lit(between_interval.lower().clone())));
154 }
155
156 let expr_interval = match expr_interval {
157 NullableInterval::Null { datatype } => {
158 return Ok(Transformed::yes(lit(
160 ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null)
161 )));
162 }
163 NullableInterval::MaybeNull { .. } => {
164 return Ok(Transformed::no(Expr::Between(between)));
166 }
167 NullableInterval::NotNull { values } => values,
168 };
169
170 let result = if between_interval.lower().is_null() {
171 let upper_bound = Interval::from(between_interval.upper().clone());
173 if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) {
174 Transformed::yes(lit(between.negated))
176 } else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) {
177 Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
179 .unwrap_or(ScalarValue::Null)))
180 } else {
181 Transformed::no(Expr::Between(between))
183 }
184 } else if between_interval.upper().is_null() {
185 let lower_bound = Interval::from(between_interval.lower().clone());
187 if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) {
188 Transformed::yes(lit(between.negated))
190 } else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) {
191 Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
193 .unwrap_or(ScalarValue::Null)))
194 } else {
195 Transformed::no(Expr::Between(between))
197 }
198 } else {
199 let contains = between_interval.contains(expr_interval)?;
200 if contains.eq(&Interval::TRUE) {
201 Transformed::yes(lit(!between.negated))
202 } else if contains.eq(&Interval::FALSE) {
203 Transformed::yes(lit(between.negated))
204 } else {
205 Transformed::no(Expr::Between(between))
206 }
207 };
208 Ok(result)
209}
210
211fn ensure_typed_null(
212 value: &ScalarValue,
213 other: &ScalarValue,
214) -> Result<ScalarValue, DataFusionError> {
215 Ok(
216 if value.data_type().is_null() && !other.data_type().is_null() {
217 ScalarValue::try_new_null(&other.data_type())?
218 } else {
219 value.clone()
220 },
221 )
222}
223
224fn rewrite_binary_expr(
225 binary: BinaryExpr,
226 guarantees: &HashMap<&Expr, &NullableInterval>,
227) -> Result<Transformed<Expr>, DataFusionError> {
228 let left_interval = guarantees
231 .get(binary.left.as_ref())
232 .map(|interval| Cow::Borrowed(*interval))
233 .or_else(|| {
234 if let Expr::Literal(value, _) = binary.left.as_ref() {
235 Some(Cow::Owned(value.clone().into()))
236 } else {
237 None
238 }
239 });
240 let right_interval = guarantees
241 .get(binary.right.as_ref())
242 .map(|interval| Cow::Borrowed(*interval))
243 .or_else(|| {
244 if let Expr::Literal(value, _) = binary.right.as_ref() {
245 Some(Cow::Owned(value.clone().into()))
246 } else {
247 None
248 }
249 });
250
251 if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) {
252 let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?;
253 if result.is_certainly_true() {
254 return Ok(Transformed::yes(lit(true)));
255 } else if result.is_certainly_false() {
256 return Ok(Transformed::yes(lit(false)));
257 }
258 }
259 Ok(Transformed::no(Expr::BinaryExpr(binary)))
260}
261
262fn rewrite_inlist(
263 inlist: InList,
264 guarantees: &HashMap<&Expr, &NullableInterval>,
265) -> Result<Transformed<Expr>, DataFusionError> {
266 let Some(interval) = guarantees.get(inlist.expr.as_ref()) else {
267 return Ok(Transformed::no(Expr::InList(inlist)));
268 };
269
270 let InList {
271 expr,
272 list,
273 negated,
274 } = inlist;
275
276 let list: Vec<Expr> = list
278 .into_iter()
279 .filter_map(|expr| {
280 if let Expr::Literal(item, _) = &expr {
281 match interval.contains(NullableInterval::from(item.clone())) {
282 Ok(interval) if interval.is_certainly_false() => None,
285 Ok(_) => Some(Ok(expr)),
286 Err(e) => Some(Err(e)),
287 }
288 } else {
289 Some(Ok(expr))
290 }
291 })
292 .collect::<Result<_, DataFusionError>>()?;
293
294 Ok(Transformed::yes(Expr::InList(InList {
295 expr,
296 list,
297 negated,
298 })))
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Operator, col};
    use datafusion_common::tree_node::TransformedResult;

    /// `IS [NOT] NULL` and `[NOT] BETWEEN` rewrites for a column that is
    /// guaranteed non-NULL with values in `[1, 3]`.
    #[test]
    fn test_not_null_guarantee() {
        // x IS NOT NULL AND x >= 1 AND x <= 3
        let guarantees = [
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::make(Some(1), Some(3)).unwrap(),
                },
            ),
        ];

        // Each case is (input, expected); `None` means the input must come
        // back unchanged.
        let is_null_cases = vec![
            // x is never NULL.
            (col("x").is_null(), Some(lit(false))),
            (col("x").is_not_null(), Some(lit(true))),
            // [1, 3] lies entirely inside [0, 10].
            (col("x").between(lit(0), lit(10)), Some(lit(true))),
            // Lower bound above upper bound: expected to be left unchanged.
            (col("x").between(lit(1), lit(-2)), None),
            // NULL low bound, i.e. `NULL AND x <= high`: FALSE when every
            // guaranteed value exceeds `high`, NULL when every value is
            // <= `high`, otherwise undecided.
            (
                col("x").between(lit(ScalarValue::Null), lit(0)),
                Some(lit(false)),
            ),
            (col("x").between(lit(ScalarValue::Null), lit(1)), None),
            (col("x").between(lit(ScalarValue::Null), lit(2)), None),
            (
                col("x").between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // NULL high bound, i.e. `x >= low AND NULL`: NULL when every
            // guaranteed value is >= `low`, FALSE when every value is below
            // `low`, otherwise undecided.
            (
                col("x").between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (col("x").between(lit(2), lit(ScalarValue::Null)), None),
            (col("x").between(lit(3), lit(ScalarValue::Null)), None),
            (
                col("x").between(lit(4), lit(ScalarValue::Null)),
                Some(lit(false)),
            ),
            // NOT BETWEEN mirrors the cases above: FALSE becomes TRUE while
            // NULL results stay NULL.
            (
                col("x").not_between(lit(ScalarValue::Null), lit(0)),
                Some(lit(true)),
            ),
            (col("x").not_between(lit(ScalarValue::Null), lit(1)), None),
            (col("x").not_between(lit(ScalarValue::Null), lit(2)), None),
            (
                col("x").not_between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").not_between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").not_between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").not_between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (col("x").not_between(lit(2), lit(ScalarValue::Null)), None),
            (col("x").not_between(lit(3), lit(ScalarValue::Null)), None),
            (
                col("x").not_between(lit(4), lit(ScalarValue::Null)),
                Some(lit(true)),
            ),
        ];

        for case in is_null_cases {
            let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = match case.1 {
                None => case.0.clone(),
                Some(expected) => expected,
            };

            assert_eq!(output, expected, "Failed for {}", case.0);
        }
    }

    /// Asserts that every `(expr, value)` case rewrites to `lit(value)`.
    fn validate_simplified_cases<T>(
        guarantees: &[(Expr, NullableInterval)],
        cases: &[(Expr, T)],
    ) where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (expr, expected_value) in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = lit(ScalarValue::from(expected_value.clone()));
            assert_eq!(
                output, expected,
                "{expr} simplified to {output}, but expected {expected}"
            );
        }
    }

    /// Asserts that every expression in `cases` comes back unchanged.
    fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) {
        for expr in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(
                &output, expr,
                "{expr} was simplified to {output}, but expected it to be unchanged"
            );
        }
    }

    /// Comparisons against a non-NULL Date32 column bounded only from below.
    #[test]
    fn test_inequalities_non_null_unbounded() {
        // x IS NOT NULL AND x >= 2021-01-01 (Date32 day 18628). The NULL
        // upper endpoint leaves the range open above.
        let guarantees = [
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Date32(Some(18628)),
                        ScalarValue::Date32(None),
                    )
                    .unwrap(),
                },
            ),
        ];

        // 17000 and 18627 lie below the guaranteed range, so these are
        // decided; a never-NULL x is always distinct from NULL.
        let simplified_cases = &[
            (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
            (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
            (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
            (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
            (
                col("x").between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                false,
            ),
            (
                col("x").not_between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Null)),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
                }),
                true,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        // 19000 lies inside the guaranteed range, so none of these can be
        // decided from the guarantee alone.
        let unchanged_cases = &[
            col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
            col("x").not_between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    /// Comparisons against a possibly-NULL string column in ["abc", "def"].
    #[test]
    fn test_inequalities_maybe_null() {
        // x >= 'abc' AND x <= 'def', and x may be NULL
        let guarantees = [
            (
                col("x"),
                NullableInterval::MaybeNull {
                    values: Interval::try_new(
                        ScalarValue::from("abc"),
                        ScalarValue::from("def"),
                    )
                    .unwrap(),
                },
            ),
        ];

        // IS [NOT] DISTINCT FROM never yields NULL, and 'z' is outside the
        // range whether or not x is NULL.
        let simplified_cases = &[
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsNotDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                false,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        // Plain comparisons may evaluate to NULL when x is NULL, and whether
        // x is distinct from NULL depends on x's nullness, so none of these
        // can be decided.
        let unchanged_cases = &[
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(col("x")),
                op: Operator::IsDistinctFrom,
                right: Box::new(lit(ScalarValue::Null)),
            }),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    /// A guarantee built from a single scalar, NULL or not, rewrites the
    /// bare column reference to that literal.
    #[test]
    fn test_column_single_value() {
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        for scalar in scalars {
            let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))];

            let output = rewrite_with_guarantees(col("x"), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(output, Expr::Literal(scalar.clone(), None));
        }
    }

    /// Pruning of `[NOT] IN` list items outside a non-NULL `[1, 10]` range.
    #[test]
    fn test_in_list() {
        // x IS NOT NULL AND x >= 1 AND x <= 10
        let guarantees = [
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Int32(Some(1)),
                        ScalarValue::Int32(Some(10)),
                    )
                    .unwrap(),
                },
            ),
        ];

        // (column, starting list, negated, expected list after pruning)
        let cases = &[
            // 11 is outside [1, 10] and is dropped.
            ("x", vec![9, 11], false, vec![9]),
            // Both items are inside [1, 10]; nothing is dropped.
            ("x", vec![10, 2], false, vec![10, 2]),
            // The same pruning applies to NOT IN.
            ("x", vec![9, 11], true, vec![9]),
            // Every item is outside the range, leaving an empty list.
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in cases {
            let expr = col(*column_name).in_list(
                starting_list
                    .iter()
                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
                    .collect(),
                *negated,
            );
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected_list = expected_list
                .iter()
                .map(|v| lit(ScalarValue::Int32(Some(*v))))
                .collect();
            assert_eq!(
                output,
                Expr::InList(InList {
                    expr: Box::new(col(*column_name)),
                    list: expected_list,
                    negated: *negated,
                })
            );
        }
    }
}