1use std::{borrow::Cow, collections::HashMap};
23
24use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
25use datafusion_common::{DataFusionError, Result};
26use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
27use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
28
/// Expression rewriter that simplifies expressions using known guarantees
/// about the values some sub-expressions can take.
///
/// The rewriter only borrows the guarantees; the `'a` lifetime ties it to the
/// `(Expr, NullableInterval)` pairs it was built from.
pub struct GuaranteeRewriter<'a> {
    /// Lookup from an expression to the interval guaranteed for it.
    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
}
45
46impl<'a> GuaranteeRewriter<'a> {
47 pub fn new(
48 guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
49 ) -> Self {
50 Self {
51 #[allow(clippy::map_identity)]
55 guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
56 }
57 }
58}
59
60impl TreeNodeRewriter for GuaranteeRewriter<'_> {
61 type Node = Expr;
62
63 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
64 if self.guarantees.is_empty() {
65 return Ok(Transformed::no(expr));
66 }
67
68 match &expr {
69 Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
70 Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))),
71 Some(NullableInterval::NotNull { .. }) => {
72 Ok(Transformed::yes(lit(false)))
73 }
74 _ => Ok(Transformed::no(expr)),
75 },
76 Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) {
77 Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))),
78 Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))),
79 _ => Ok(Transformed::no(expr)),
80 },
81 Expr::Between(Between {
82 expr: inner,
83 negated,
84 low,
85 high,
86 }) => {
87 if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
88 self.guarantees.get(inner.as_ref()),
89 low.as_ref(),
90 high.as_ref(),
91 ) {
92 let expr_interval = NullableInterval::NotNull {
93 values: Interval::try_new(low.clone(), high.clone())?,
94 };
95
96 let contains = expr_interval.contains(*interval)?;
97
98 if contains.is_certainly_true() {
99 Ok(Transformed::yes(lit(!negated)))
100 } else if contains.is_certainly_false() {
101 Ok(Transformed::yes(lit(*negated)))
102 } else {
103 Ok(Transformed::no(expr))
104 }
105 } else {
106 Ok(Transformed::no(expr))
107 }
108 }
109
110 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
111 let left_interval = self
114 .guarantees
115 .get(left.as_ref())
116 .map(|interval| Cow::Borrowed(*interval))
117 .or_else(|| {
118 if let Expr::Literal(value, _) = left.as_ref() {
119 Some(Cow::Owned(value.clone().into()))
120 } else {
121 None
122 }
123 });
124 let right_interval = self
125 .guarantees
126 .get(right.as_ref())
127 .map(|interval| Cow::Borrowed(*interval))
128 .or_else(|| {
129 if let Expr::Literal(value, _) = right.as_ref() {
130 Some(Cow::Owned(value.clone().into()))
131 } else {
132 None
133 }
134 });
135
136 match (left_interval, right_interval) {
137 (Some(left_interval), Some(right_interval)) => {
138 let result =
139 left_interval.apply_operator(op, right_interval.as_ref())?;
140 if result.is_certainly_true() {
141 Ok(Transformed::yes(lit(true)))
142 } else if result.is_certainly_false() {
143 Ok(Transformed::yes(lit(false)))
144 } else {
145 Ok(Transformed::no(expr))
146 }
147 }
148 _ => Ok(Transformed::no(expr)),
149 }
150 }
151
152 Expr::Column(_) => {
154 if let Some(interval) = self.guarantees.get(&expr) {
155 Ok(Transformed::yes(interval.single_value().map_or(expr, lit)))
156 } else {
157 Ok(Transformed::no(expr))
158 }
159 }
160
161 Expr::InList(InList {
162 expr: inner,
163 list,
164 negated,
165 }) => {
166 if let Some(interval) = self.guarantees.get(inner.as_ref()) {
167 let new_list: Vec<Expr> = list
169 .iter()
170 .filter_map(|expr| {
171 if let Expr::Literal(item, _) = expr {
172 match interval
173 .contains(NullableInterval::from(item.clone()))
174 {
175 Ok(interval) if interval.is_certainly_false() => None,
178 Ok(_) => Some(Ok(expr.clone())),
179 Err(e) => Some(Err(e)),
180 }
181 } else {
182 Some(Ok(expr.clone()))
183 }
184 })
185 .collect::<Result<_, DataFusionError>>()?;
186
187 Ok(Transformed::yes(Expr::InList(InList {
188 expr: inner.clone(),
189 list: new_list,
190 negated: *negated,
191 })))
192 } else {
193 Ok(Transformed::no(expr))
194 }
195 }
196
197 _ => Ok(Transformed::no(expr)),
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::datatypes::DataType;
    use datafusion_common::tree_node::{TransformedResult, TreeNode};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{col, Operator};

    /// Shorthand for a non-null `Date32` literal expression.
    fn date32(days: i32) -> Expr {
        lit(ScalarValue::Date32(Some(days)))
    }

    /// Shorthand for the binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Shorthand for a list of non-null `Int32` literal expressions.
    fn int32_literals(values: &[i32]) -> Vec<Expr> {
        values
            .iter()
            .map(|v| lit(ScalarValue::Int32(Some(*v))))
            .collect()
    }

    #[test]
    fn test_null_handling() {
        // `x` is guaranteed to be non-null, with unbounded Boolean values.
        let not_null = NullableInterval::NotNull {
            values: Interval::make_unbounded(&DataType::Boolean).unwrap(),
        };
        let guarantees = vec![(col("x"), not_null)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // x IS NULL => false
        let rewritten = col("x").is_null().rewrite(&mut rewriter).data().unwrap();
        assert_eq!(rewritten, lit(false));

        // x IS NOT NULL => true
        let rewritten = col("x")
            .is_not_null()
            .rewrite(&mut rewriter)
            .data()
            .unwrap();
        assert_eq!(rewritten, lit(true));
    }

    /// Asserts that every expression in `cases` is rewritten to the literal
    /// built from its paired expected value.
    fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)])
    where
        ScalarValue: From<T>,
        T: Clone,
    {
        for (expr, expected_value) in cases.iter() {
            let expected = lit(ScalarValue::from(expected_value.clone()));
            let output = expr.clone().rewrite(&mut *rewriter).data().unwrap();
            assert_eq!(
                output, expected,
                "{expr} simplified to {output}, but expected {expected}"
            );
        }
    }

    /// Asserts that every expression in `cases` comes out of the rewriter
    /// equal to how it went in.
    fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) {
        for expr in cases.iter() {
            let output = expr.clone().rewrite(&mut *rewriter).data().unwrap();
            assert_eq!(
                &output, expr,
                "{expr} was simplified to {output}, but expected it to be unchanged"
            );
        }
    }

    #[test]
    fn test_inequalities_non_null_unbounded() {
        // `x` is non-null with lower bound Date32 18628; the upper bound is
        // given as `Date32(None)`.
        let lower_bounded = NullableInterval::NotNull {
            values: Interval::try_new(
                ScalarValue::Date32(Some(18628)),
                ScalarValue::Date32(None),
            )
            .unwrap(),
        };
        let guarantees = vec![(col("x"), lower_bounded)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // Comparisons whose outcome is fully decided by the guarantee.
        let simplified_cases = [
            (col("x").lt(date32(18628)), false),
            (col("x").lt_eq(date32(17000)), false),
            (col("x").gt(date32(18627)), true),
            (col("x").gt_eq(date32(18628)), true),
            (col("x").eq(date32(17000)), false),
            (col("x").not_eq(date32(17000)), true),
            (col("x").between(date32(16000), date32(17000)), false),
            (col("x").not_between(date32(16000), date32(17000)), true),
            (
                binary(col("x"), Operator::IsDistinctFrom, lit(ScalarValue::Null)),
                true,
            ),
            (
                binary(col("x"), Operator::IsDistinctFrom, date32(17000)),
                true,
            ),
        ];
        validate_simplified_cases(&mut rewriter, &simplified_cases);

        // Comparisons against values the guarantee does not decide.
        let unchanged_cases = [
            col("x").lt(date32(19000)),
            col("x").lt_eq(date32(19000)),
            col("x").gt(date32(19000)),
            col("x").gt_eq(date32(19000)),
            col("x").eq(date32(19000)),
            col("x").not_eq(date32(19000)),
            col("x").between(date32(18000), date32(19000)),
            col("x").not_between(date32(18000), date32(19000)),
        ];
        validate_unchanged_cases(&mut rewriter, &unchanged_cases);
    }

    #[test]
    fn test_inequalities_maybe_null() {
        // `x` is either null or a string between "abc" and "def".
        let maybe_null = NullableInterval::MaybeNull {
            values: Interval::try_new(ScalarValue::from("abc"), ScalarValue::from("def"))
                .unwrap(),
        };
        let guarantees = vec![(col("x"), maybe_null)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // Only the null-safe comparisons against an out-of-range value are
        // expected to be decided.
        let simplified_cases = [
            (binary(col("x"), Operator::IsDistinctFrom, lit("z")), true),
            (
                binary(col("x"), Operator::IsNotDistinctFrom, lit("z")),
                false,
            ),
        ];
        validate_simplified_cases(&mut rewriter, &simplified_cases);

        let unchanged_cases = [
            col("x").lt(lit("z")),
            col("x").lt_eq(lit("z")),
            col("x").gt(lit("a")),
            col("x").gt_eq(lit("a")),
            col("x").eq(lit("abc")),
            col("x").not_eq(lit("a")),
            col("x").between(lit("a"), lit("z")),
            col("x").not_between(lit("a"), lit("z")),
            binary(col("x"), Operator::IsDistinctFrom, lit(ScalarValue::Null)),
        ];
        validate_unchanged_cases(&mut rewriter, &unchanged_cases);
    }

    #[test]
    fn test_column_single_value() {
        let scalars = [
            ScalarValue::Null,
            ScalarValue::Int32(Some(1)),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Boolean(None),
            ScalarValue::from("abc"),
            ScalarValue::LargeUtf8(Some("def".to_string())),
            ScalarValue::Date32(Some(18628)),
            ScalarValue::Date32(None),
            ScalarValue::Decimal128(Some(1000), 19, 2),
        ];

        // A guarantee built from a single scalar should turn the column into
        // that scalar as a literal.
        for scalar in scalars {
            let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))];
            let mut rewriter = GuaranteeRewriter::new(&guarantees);

            let output = col("x").rewrite(&mut rewriter).data().unwrap();
            assert_eq!(output, Expr::Literal(scalar, None));
        }
    }

    #[test]
    fn test_in_list() {
        // `x` is a non-null Int32 between 1 and 10.
        let one_to_ten = NullableInterval::NotNull {
            values: Interval::try_new(
                ScalarValue::Int32(Some(1)),
                ScalarValue::Int32(Some(10)),
            )
            .unwrap(),
        };
        let guarantees = vec![(col("x"), one_to_ten)];
        let mut rewriter = GuaranteeRewriter::new(&guarantees);

        // (column, starting list, negated, expected list)
        let cases: [(&str, Vec<i32>, bool, Vec<i32>); 4] = [
            // x IN (9, 11) => x IN (9)
            ("x", vec![9, 11], false, vec![9]),
            // x IN (10, 2) => x IN (10, 2)
            ("x", vec![10, 2], false, vec![10, 2]),
            // x NOT IN (9, 11) => x NOT IN (9)
            ("x", vec![9, 11], true, vec![9]),
            // x NOT IN (0, 22) => x NOT IN ()
            ("x", vec![0, 22], true, vec![]),
        ];

        for (column_name, starting_list, negated, expected_list) in &cases {
            let input = col(*column_name).in_list(int32_literals(starting_list), *negated);
            let output = input.rewrite(&mut rewriter).data().unwrap();

            let expected = Expr::InList(InList {
                expr: Box::new(col(*column_name)),
                list: int32_literals(expected_list),
                negated: *negated,
            });
            assert_eq!(output, expected);
        }
    }
}