1use std::{borrow::Cow, collections::HashMap};
23
24use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
25use datafusion_common::{DataFusionError, Result};
26use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
27use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
28
29pub struct GuaranteeRewriter<'a> {
43 guarantees: HashMap<&'a Expr, &'a NullableInterval>,
44}
45
46impl<'a> GuaranteeRewriter<'a> {
47 pub fn new(
48 guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
49 ) -> Self {
50 Self {
51 #[allow(clippy::map_identity)]
55 guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
56 }
57 }
58}
59
60impl TreeNodeRewriter for GuaranteeRewriter<'_> {
61 type Node = Expr;
62
63 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
64 if self.guarantees.is_empty() {
65 return Ok(Transformed::no(expr));
66 }
67
68 match &expr {
69 Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
70 Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))),
71 Some(NullableInterval::NotNull { .. }) => {
72 Ok(Transformed::yes(lit(false)))
73 }
74 _ => Ok(Transformed::no(expr)),
75 },
76 Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) {
77 Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))),
78 Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))),
79 _ => Ok(Transformed::no(expr)),
80 },
81 Expr::Between(Between {
82 expr: inner,
83 negated,
84 low,
85 high,
86 }) => {
87 if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = (
88 self.guarantees.get(inner.as_ref()),
89 low.as_ref(),
90 high.as_ref(),
91 ) {
92 let expr_interval = NullableInterval::NotNull {
93 values: Interval::try_new(low.clone(), high.clone())?,
94 };
95
96 let contains = expr_interval.contains(*interval)?;
97
98 if contains.is_certainly_true() {
99 Ok(Transformed::yes(lit(!negated)))
100 } else if contains.is_certainly_false() {
101 Ok(Transformed::yes(lit(*negated)))
102 } else {
103 Ok(Transformed::no(expr))
104 }
105 } else {
106 Ok(Transformed::no(expr))
107 }
108 }
109
110 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
111 let left_interval = self
114 .guarantees
115 .get(left.as_ref())
116 .map(|interval| Cow::Borrowed(*interval))
117 .or_else(|| {
118 if let Expr::Literal(value) = left.as_ref() {
119 Some(Cow::Owned(value.clone().into()))
120 } else {
121 None
122 }
123 });
124 let right_interval = self
125 .guarantees
126 .get(right.as_ref())
127 .map(|interval| Cow::Borrowed(*interval))
128 .or_else(|| {
129 if let Expr::Literal(value) = right.as_ref() {
130 Some(Cow::Owned(value.clone().into()))
131 } else {
132 None
133 }
134 });
135
136 match (left_interval, right_interval) {
137 (Some(left_interval), Some(right_interval)) => {
138 let result =
139 left_interval.apply_operator(op, right_interval.as_ref())?;
140 if result.is_certainly_true() {
141 Ok(Transformed::yes(lit(true)))
142 } else if result.is_certainly_false() {
143 Ok(Transformed::yes(lit(false)))
144 } else {
145 Ok(Transformed::no(expr))
146 }
147 }
148 _ => Ok(Transformed::no(expr)),
149 }
150 }
151
152 Expr::Column(_) => {
154 if let Some(interval) = self.guarantees.get(&expr) {
155 Ok(Transformed::yes(interval.single_value().map_or(expr, lit)))
156 } else {
157 Ok(Transformed::no(expr))
158 }
159 }
160
161 Expr::InList(InList {
162 expr: inner,
163 list,
164 negated,
165 }) => {
166 if let Some(interval) = self.guarantees.get(inner.as_ref()) {
167 let new_list: Vec<Expr> = list
169 .iter()
170 .filter_map(|expr| {
171 if let Expr::Literal(item) = expr {
172 match interval
173 .contains(NullableInterval::from(item.clone()))
174 {
175 Ok(interval) if interval.is_certainly_false() => None,
178 Ok(_) => Some(Ok(expr.clone())),
179 Err(e) => Some(Err(e)),
180 }
181 } else {
182 Some(Ok(expr.clone()))
183 }
184 })
185 .collect::<Result<_, DataFusionError>>()?;
186
187 Ok(Transformed::yes(Expr::InList(InList {
188 expr: inner.clone(),
189 list: new_list,
190 negated: *negated,
191 })))
192 } else {
193 Ok(Transformed::no(expr))
194 }
195 }
196
197 _ => Ok(Transformed::no(expr)),
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::datatypes::DataType;
    use datafusion_common::tree_node::{TransformedResult, TreeNode};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{col, Operator};

    /// Literal `Date32` expression holding `days`.
    fn date32(days: i32) -> Expr {
        lit(ScalarValue::Date32(Some(days)))
    }

    /// Builds `left <op> right` directly as an `Expr::BinaryExpr`.
    fn binary_op(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Turns a slice of integers into a list of `Int32` literal expressions.
    fn int32_list(values: &[i32]) -> Vec<Expr> {
        values
            .iter()
            .map(|v| lit(ScalarValue::Int32(Some(*v))))
            .collect()
    }

    #[test]
    fn test_null_handling() {
        // x is guaranteed to be non-null (values otherwise unbounded).
        let not_null = NullableInterval::NotNull {
            values: Interval::make_unbounded(&DataType::Boolean).unwrap(),
        };
        let guarantees = vec![(col("x"), not_null)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // x IS NULL => false
        let is_null = col("x").is_null();
        let rewritten = is_null.rewrite(&mut rewriter).data().unwrap();
        assert_eq!(rewritten, lit(false));

        // x IS NOT NULL => true
        let is_not_null = col("x").is_not_null();
        let rewritten = is_not_null.rewrite(&mut rewriter).data().unwrap();
        assert_eq!(rewritten, lit(true));
    }

    /// Asserts that every expression in `cases` is rewritten to the literal
    /// built from its paired expected value.
    fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)])
    where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (expr, expected_value) in cases {
            let expected = lit(ScalarValue::from(expected_value.clone()));
            let output = expr.clone().rewrite(&mut *rewriter).data().unwrap();
            assert_eq!(
                output, expected,
                "{} simplified to {}, but expected {}",
                expr, output, expected
            );
        }
    }

    /// Asserts that the rewriter leaves every expression in `cases` as is.
    fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) {
        for expr in cases {
            let output = expr.clone().rewrite(&mut *rewriter).data().unwrap();
            assert_eq!(
                &output, expr,
                "{} was simplified to {}, but expected it to be unchanged",
                expr, output
            );
        }
    }

    #[test]
    fn test_inequalities_non_null_unbounded() {
        // x is non-null, with lower bound Date32(18628) and a `None` upper bound.
        let lower_bounded = NullableInterval::NotNull {
            values: Interval::try_new(
                ScalarValue::Date32(Some(18628)),
                ScalarValue::Date32(None),
            )
            .unwrap(),
        };
        let guarantees = vec![(col("x"), lower_bounded)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (col("x").lt(date32(18628)), false),
            (col("x").lt_eq(date32(17000)), false),
            (col("x").gt(date32(18627)), true),
            (col("x").gt_eq(date32(18628)), true),
            (col("x").eq(date32(17000)), false),
            (col("x").not_eq(date32(17000)), true),
            (col("x").between(date32(16000), date32(17000)), false),
            (col("x").not_between(date32(16000), date32(17000)), true),
            (
                binary_op(
                    col("x"),
                    Operator::IsDistinctFrom,
                    lit(ScalarValue::Null),
                ),
                true,
            ),
            (
                binary_op(col("x"), Operator::IsDistinctFrom, date32(17000)),
                true,
            ),
        ];

        validate_simplified_cases(&mut rewriter, simplified_cases);

        // Comparisons against a value inside the guaranteed range stay as is.
        let unchanged_cases = &[
            col("x").lt(date32(19000)),
            col("x").lt_eq(date32(19000)),
            col("x").gt(date32(19000)),
            col("x").gt_eq(date32(19000)),
            col("x").eq(date32(19000)),
            col("x").not_eq(date32(19000)),
            col("x").between(date32(18000), date32(19000)),
            col("x").not_between(date32(18000), date32(19000)),
        ];

        validate_unchanged_cases(&mut rewriter, unchanged_cases);
    }

    #[test]
    fn test_inequalities_maybe_null() {
        // x is in ["abc", "def"] but may also be null.
        let maybe_null = NullableInterval::MaybeNull {
            values: Interval::try_new(ScalarValue::from("abc"), ScalarValue::from("def"))
                .unwrap(),
        };
        let guarantees = vec![(col("x"), maybe_null)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // (original_expr, expected_simplification)
        let simplified_cases = &[
            (
                binary_op(col("x"), Operator::IsDistinctFrom, lit("z")),
                true,
            ),
            (
                binary_op(col("x"), Operator::IsNotDistinctFrom, lit("z")),
                false,
            ),
        ];

        validate_simplified_cases(&mut rewriter, simplified_cases);

        let unchanged_cases = &[
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            binary_op(
                col("x"),
                Operator::IsDistinctFrom,
                lit(ScalarValue::Null),
            ),
        ];

        validate_unchanged_cases(&mut rewriter, unchanged_cases);
    }

    #[test]
    fn test_column_single_value() {
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        // A column guaranteed to equal one scalar is replaced by that literal.
        for scalar in scalars {
            let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))];
            let mut rewriter = GuaranteeRewriter::new(&guarantees);

            let rewritten = col("x").rewrite(&mut rewriter).data().unwrap();
            assert_eq!(rewritten, Expr::Literal(scalar));
        }
    }

    #[test]
    fn test_in_list() {
        // x is non-null and within [1, 10].
        let one_to_ten = NullableInterval::NotNull {
            values: Interval::try_new(
                ScalarValue::Int32(Some(1)),
                ScalarValue::Int32(Some(10)),
            )
            .unwrap(),
        };
        let guarantees = vec![(col("x"), one_to_ten)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // (column_name, starting_list, negated, expected_list)
        let cases = &[
            // x IN (9, 11) => x IN (9)
            ("x", vec![9, 11], false, vec![9]),
            // x IN (10, 2) => x IN (10, 2)
            ("x", vec![10, 2], false, vec![10, 2]),
            // x NOT IN (9, 11) => x NOT IN (9)
            ("x", vec![9, 11], true, vec![9]),
            // x NOT IN (0, 22) => x NOT IN ()
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in cases {
            let input = col(*column_name).in_list(int32_list(starting_list), *negated);
            let rewritten = input.rewrite(&mut rewriter).data().unwrap();

            let expected = Expr::InList(InList {
                expr: Box::new(col(*column_name)),
                list: int32_list(expected_list),
                negated: *negated,
            });
            assert_eq!(rewritten, expected);
        }
    }
}