1use crate::{Between, BinaryExpr, Expr, expr::InList, lit};
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue};
23use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval};
24use std::borrow::Cow;
25
/// Rewrites expressions using known guarantees about the values that
/// (sub)expressions can take.
///
/// Each guarantee associates an expression with a [`NullableInterval`]
/// describing its possible values and nullness. This type implements
/// [`TreeNodeRewriter`], applying the guarantees in `f_up`, i.e. after an
/// expression's children have already been visited.
pub struct GuaranteeRewriter<'a> {
    /// Guarantees keyed by the expression they constrain; both the key and the
    /// interval are borrowed for `'a`.
    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
}
32
33impl<'a> GuaranteeRewriter<'a> {
34 pub fn new(
35 guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
36 ) -> Self {
37 Self {
38 guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
39 }
40 }
41}
42
43pub fn rewrite_with_guarantees<'a>(
61 expr: Expr,
62 guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
63) -> Result<Transformed<Expr>> {
64 let guarantees_map: HashMap<&Expr, &NullableInterval> =
65 guarantees.into_iter().map(|(k, v)| (k, v)).collect();
66 rewrite_with_guarantees_map(expr, &guarantees_map)
67}
68
69pub fn rewrite_with_guarantees_map<'a>(
78 expr: Expr,
79 guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>,
80) -> Result<Transformed<Expr>> {
81 if guarantees.is_empty() {
82 return Ok(Transformed::no(expr));
83 }
84
85 expr.transform_up(|e| rewrite_expr(e, guarantees))
86}
87
88impl TreeNodeRewriter for GuaranteeRewriter<'_> {
89 type Node = Expr;
90
91 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
92 if self.guarantees.is_empty() {
93 return Ok(Transformed::no(expr));
94 }
95
96 rewrite_expr(expr, &self.guarantees)
97 }
98}
99
100fn rewrite_expr(
101 expr: Expr,
102 guarantees: &HashMap<&Expr, &NullableInterval>,
103) -> Result<Transformed<Expr>> {
104 if let Some(interval) = guarantees.get(&expr)
106 && let Some(value) = interval.single_value()
107 {
108 return Ok(Transformed::yes(lit(value)));
109 }
110
111 let result = match expr {
112 Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) {
113 Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)),
114 Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)),
115 _ => Transformed::no(Expr::IsNull(inner)),
116 },
117 Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) {
118 Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)),
119 Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)),
120 _ => Transformed::no(Expr::IsNotNull(inner)),
121 },
122 Expr::Between(b) => rewrite_between(b, guarantees)?,
123 Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?,
124 Expr::InList(i) => rewrite_inlist(i, guarantees)?,
125 expr => Transformed::no(expr),
126 };
127 Ok(result)
128}
129
130fn rewrite_between(
131 between: Between,
132 guarantees: &HashMap<&Expr, &NullableInterval>,
133) -> Result<Transformed<Expr>> {
134 let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
135 guarantees.get(between.expr.as_ref()),
136 between.low.as_ref(),
137 between.high.as_ref(),
138 ) else {
139 return Ok(Transformed::no(Expr::Between(between)));
140 };
141
142 let low = ensure_typed_null(low, high)?;
144 let high = ensure_typed_null(high, &low)?;
145
146 let Ok(between_interval) = Interval::try_new(low, high) else {
147 return Ok(Transformed::no(Expr::Between(between)));
150 };
151
152 if between_interval.lower().is_null() && between_interval.upper().is_null() {
153 return Ok(Transformed::yes(lit(between_interval.lower().clone())));
154 }
155
156 let expr_interval = match expr_interval {
157 NullableInterval::Null { datatype } => {
158 return Ok(Transformed::yes(lit(
160 ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null)
161 )));
162 }
163 NullableInterval::MaybeNull { .. } => {
164 return Ok(Transformed::no(Expr::Between(between)));
166 }
167 NullableInterval::NotNull { values } => values,
168 };
169
170 let result = if between_interval.lower().is_null() {
171 let upper_bound = Interval::from(between_interval.upper().clone());
173 if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) {
174 Transformed::yes(lit(between.negated))
176 } else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) {
177 Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
179 .unwrap_or(ScalarValue::Null)))
180 } else {
181 Transformed::no(Expr::Between(between))
183 }
184 } else if between_interval.upper().is_null() {
185 let lower_bound = Interval::from(between_interval.lower().clone());
187 if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) {
188 Transformed::yes(lit(between.negated))
190 } else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) {
191 Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
193 .unwrap_or(ScalarValue::Null)))
194 } else {
195 Transformed::no(Expr::Between(between))
197 }
198 } else {
199 let contains = between_interval.contains(expr_interval)?;
200 if contains.eq(&Interval::TRUE) {
201 Transformed::yes(lit(!between.negated))
202 } else if contains.eq(&Interval::FALSE) {
203 Transformed::yes(lit(between.negated))
204 } else {
205 Transformed::no(Expr::Between(between))
206 }
207 };
208 Ok(result)
209}
210
211fn ensure_typed_null(
212 value: &ScalarValue,
213 other: &ScalarValue,
214) -> Result<ScalarValue, DataFusionError> {
215 Ok(
216 if value.data_type().is_null() && !other.data_type().is_null() {
217 ScalarValue::try_new_null(&other.data_type())?
218 } else {
219 value.clone()
220 },
221 )
222}
223
224fn rewrite_binary_expr(
225 binary: BinaryExpr,
226 guarantees: &HashMap<&Expr, &NullableInterval>,
227) -> Result<Transformed<Expr>, DataFusionError> {
228 let left_interval = guarantees
231 .get(binary.left.as_ref())
232 .map(|interval| Cow::Borrowed(*interval))
233 .or_else(|| {
234 if let Expr::Literal(value, _) = binary.left.as_ref() {
235 Some(Cow::Owned(value.clone().into()))
236 } else {
237 None
238 }
239 });
240 let right_interval = guarantees
241 .get(binary.right.as_ref())
242 .map(|interval| Cow::Borrowed(*interval))
243 .or_else(|| {
244 if let Expr::Literal(value, _) = binary.right.as_ref() {
245 Some(Cow::Owned(value.clone().into()))
246 } else {
247 None
248 }
249 });
250
251 if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) {
252 let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?;
253 if result.is_certainly_true() {
254 return Ok(Transformed::yes(lit(true)));
255 } else if result.is_certainly_false() {
256 return Ok(Transformed::yes(lit(false)));
257 }
258 }
259 Ok(Transformed::no(Expr::BinaryExpr(binary)))
260}
261
262fn rewrite_inlist(
263 inlist: InList,
264 guarantees: &HashMap<&Expr, &NullableInterval>,
265) -> Result<Transformed<Expr>, DataFusionError> {
266 let Some(interval) = guarantees.get(inlist.expr.as_ref()) else {
267 return Ok(Transformed::no(Expr::InList(inlist)));
268 };
269
270 let InList {
271 expr,
272 list,
273 negated,
274 } = inlist;
275
276 let list: Vec<Expr> = list
278 .into_iter()
279 .filter_map(|expr| {
280 if let Expr::Literal(item, _) = &expr {
281 match interval.contains(NullableInterval::from(item.clone())) {
282 Ok(interval) if interval.is_certainly_false() => None,
285 Ok(_) => Some(Ok(expr)),
286 Err(e) => Some(Err(e)),
287 }
288 } else {
289 Some(Ok(expr))
290 }
291 })
292 .collect::<Result<_, DataFusionError>>()?;
293
294 Ok(Transformed::yes(Expr::InList(InList {
295 expr,
296 list,
297 negated,
298 })))
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Operator, col};
    use datafusion_common::ScalarValue;
    use datafusion_common::tree_node::TransformedResult;

    /// Guarantee: `x` is non-null and lies in `[1, 3]`.
    /// Checks null tests and `[NOT] BETWEEN` with literal and NULL bounds.
    #[test]
    fn test_not_null_guarantee() {
        let guarantees = [
            // x IS NOT NULL AND 1 <= x <= 3
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::make(Some(1), Some(3)).unwrap(),
                },
            ),
        ];

        // Each case is (input, expected); `None` means the input must come
        // back unchanged.
        let is_null_cases = vec![
            // A guaranteed non-null column is never NULL...
            (col("x").is_null(), Some(lit(false))),
            // ...and always NOT NULL.
            (col("x").is_not_null(), Some(lit(true))),
            // [1, 3] lies entirely inside [0, 10].
            (col("x").between(lit(0), lit(10)), Some(lit(true))),
            // Inverted bounds (low > high) are expected to be left unchanged.
            (col("x").between(lit(1), lit(-2)), None),
            // NULL lower bound, i.e. `x >= NULL AND x <= high`:
            // x is always above 0, so the conjunction is FALSE.
            (
                col("x").between(lit(ScalarValue::Null), lit(0)),
                Some(lit(false)),
            ),
            // x may be on either side of 1 or 2: undecidable.
            (col("x").between(lit(ScalarValue::Null), lit(1)), None),
            (col("x").between(lit(ScalarValue::Null), lit(2)), None),
            // x is always <= 3 (and <= 4), so the result is a NULL typed like x.
            (
                col("x").between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // NULL upper bound, i.e. `x >= low AND x <= NULL`:
            // x is always >= 0 (and >= 1), so the result is NULL.
            (
                col("x").between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            // x may be on either side of 2 or 3: undecidable.
            (col("x").between(lit(2), lit(ScalarValue::Null)), None),
            (col("x").between(lit(3), lit(ScalarValue::Null)), None),
            // x is always below 4, so the conjunction is FALSE.
            (
                col("x").between(lit(4), lit(ScalarValue::Null)),
                Some(lit(false)),
            ),
            // NOT BETWEEN mirrors the cases above: FALSE becomes TRUE, and
            // NULL and undecidable results stay as they are.
            (
                col("x").not_between(lit(ScalarValue::Null), lit(0)),
                Some(lit(true)),
            ),
            (col("x").not_between(lit(ScalarValue::Null), lit(1)), None),
            (col("x").not_between(lit(ScalarValue::Null), lit(2)), None),
            (
                col("x").not_between(lit(ScalarValue::Null), lit(3)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").not_between(lit(ScalarValue::Null), lit(4)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").not_between(lit(0), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (
                col("x").not_between(lit(1), lit(ScalarValue::Null)),
                Some(lit(ScalarValue::Int32(None))),
            ),
            (col("x").not_between(lit(2), lit(ScalarValue::Null)), None),
            (col("x").not_between(lit(3), lit(ScalarValue::Null)), None),
            (
                col("x").not_between(lit(4), lit(ScalarValue::Null)),
                Some(lit(true)),
            ),
        ];

        for case in is_null_cases {
            let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = match case.1 {
                None => case.0.clone(),
                Some(expected) => expected,
            };

            assert_eq!(output, expected, "Failed for {}", case.0);
        }
    }

    /// Asserts that every `(expr, value)` case rewrites to `lit(value)` under
    /// `guarantees`.
    fn validate_simplified_cases<T>(
        guarantees: &[(Expr, NullableInterval)],
        cases: &[(Expr, T)],
    ) where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (expr, expected_value) in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected = lit(ScalarValue::from(expected_value.clone()));
            assert_eq!(
                output, expected,
                "{expr} simplified to {output}, but expected {expected}"
            );
        }
    }

    /// Asserts that every expression in `cases` is returned unchanged under
    /// `guarantees`.
    fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) {
        for expr in cases {
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(
                &output, expr,
                "{expr} was simplified to {output}, but expected it to be unchanged"
            );
        }
    }

    /// Guarantee: `x` is a non-null Date32 with a lower bound and no upper bound.
    #[test]
    fn test_inequalities_non_null_unbounded() {
        let guarantees = [
            // x IS NOT NULL AND x >= 18628 (2021-01-01); a NULL upper bound
            // means "unbounded".
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Date32(Some(18628)),
                        ScalarValue::Date32(None),
                    )
                    .unwrap(),
                },
            ),
        ];

        // Comparisons against values below the lower bound are fully decided.
        let simplified_cases = &[
            (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
            (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
            (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
            (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
            (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
            (
                col("x").between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                false,
            ),
            (
                col("x").not_between(
                    lit(ScalarValue::Date32(Some(16000))),
                    lit(ScalarValue::Date32(Some(17000))),
                ),
                true,
            ),
            // A non-null x is always distinct from NULL.
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Null)),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
                }),
                true,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        // 19000 lies inside x's range, so none of these can be decided.
        let unchanged_cases = &[
            col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
            col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
            col("x").between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
            col("x").not_between(
                lit(ScalarValue::Date32(Some(18000))),
                lit(ScalarValue::Date32(Some(19000))),
            ),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    /// Guarantee: `x` is either NULL or a string in `["abc", "def"]`.
    #[test]
    fn test_inequalities_maybe_null() {
        let guarantees = [
            // x IS NULL OR (x >= 'abc' AND x <= 'def')
            (
                col("x"),
                NullableInterval::MaybeNull {
                    values: Interval::try_new(
                        ScalarValue::from("abc"),
                        ScalarValue::from("def"),
                    )
                    .unwrap(),
                },
            ),
        ];

        // IS [NOT] DISTINCT FROM never yields NULL, so these can be decided
        // even though x may be NULL.
        let simplified_cases = &[
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                true,
            ),
            (
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(col("x")),
                    op: Operator::IsNotDistinctFrom,
                    right: Box::new(lit("z")),
                }),
                false,
            ),
        ];

        validate_simplified_cases(&guarantees, simplified_cases);

        // Plain comparisons may evaluate to NULL when x is NULL, and x may or
        // may not be NULL, so none of these fold to a literal.
        let unchanged_cases = &[
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            Expr::BinaryExpr(BinaryExpr {
                left: Box::new(col("x")),
                op: Operator::IsDistinctFrom,
                right: Box::new(lit(ScalarValue::Null)),
            }),
        ];

        validate_unchanged_cases(&guarantees, unchanged_cases);
    }

    /// A guarantee built from a single scalar replaces the column with that
    /// literal, for a variety of types including NULLs.
    #[test]
    fn test_column_single_value() {
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        for scalar in scalars {
            let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))];

            let output = rewrite_with_guarantees(col("x"), guarantees.iter())
                .data()
                .unwrap();
            assert_eq!(output, Expr::Literal(scalar.clone(), None));
        }
    }

    /// Guarantee: `x` is non-null and lies in `[1, 10]`; IN-list items
    /// outside that range are removed.
    #[test]
    fn test_in_list() {
        let guarantees = [
            // x IS NOT NULL AND 1 <= x <= 10
            (
                col("x"),
                NullableInterval::NotNull {
                    values: Interval::try_new(
                        ScalarValue::Int32(Some(1)),
                        ScalarValue::Int32(Some(10)),
                    )
                    .unwrap(),
                },
            ),
        ];

        // (column, starting list, negated, expected list after the rewrite)
        let cases = &[
            // 11 is out of range, so it is dropped.
            ("x", vec![9, 11], false, vec![9]),
            // Both items are in range; nothing is dropped.
            ("x", vec![10, 2], false, vec![10, 2]),
            // Out-of-range items are dropped for NOT IN as well.
            ("x", vec![9, 11], true, vec![9]),
            // Every item is out of range; the list becomes empty.
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in cases {
            let expr = col(*column_name).in_list(
                starting_list
                    .iter()
                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
                    .collect(),
                *negated,
            );
            let output = rewrite_with_guarantees(expr.clone(), guarantees.iter())
                .data()
                .unwrap();
            let expected_list = expected_list
                .iter()
                .map(|v| lit(ScalarValue::Int32(Some(*v))))
                .collect();
            assert_eq!(
                output,
                Expr::InList(InList {
                    expr: Box::new(col(*column_name)),
                    list: expected_list,
                    negated: *negated,
                })
            );
        }
    }
}