datafusion_materialized_views/rewrite/normal_form.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

/*!

This module contains code primarily used for view matching. We implement the view matching algorithm from [this paper](https://courses.cs.washington.edu/courses/cse591d/01sp/opt_views.pdf),
which provides a method for determining when one Select-Project-Join query can be rewritten in terms of another Select-Project-Join query.

The implementation is contained in [`SpjNormalForm::rewrite_from`]. The method can be summarized as follows:
1. Compute column equivalence classes for the query and the view.
2. Compute range intervals for the query and the view.
3. (Equijoin subsumption test) Check that each column equivalence class of the view is a subset of a column equivalence class of the query.
4. (Range subsumption test) Check that each range of the view contains the corresponding range from the query.
5. (Residual subsumption test) Check that every filter in the view that is not a column equivalence relation or a range filter matches a filter from the query.
6. Compute any compensating filters needed in order to restrict the view's rows to match the query.
7. Check that the output of the query, and the compensating filters, can be rewritten using the view's columns as inputs.

# Example

Consider the following table:

```sql
CREATE TABLE example (
    l_orderkey INT,
    l_partkey INT,
    l_shipdate DATE,
    l_quantity DOUBLE,
    l_extendedprice DOUBLE,
    o_custkey INT,
    o_orderkey INT,
    o_orderdate DATE,
    p_name VARCHAR,
    p_partkey INT
)
```

And consider the following view:

```sql
CREATE VIEW mv AS SELECT
    l_orderkey,
    o_custkey,
    l_partkey,
    l_shipdate,
    o_orderdate,
    l_quantity*l_extendedprice AS gross_revenue
FROM example
WHERE
    l_orderkey = o_orderkey AND
    l_partkey = p_partkey AND
    p_partkey >= 150 AND
    o_custkey >= 50 AND
    o_custkey <= 500 AND
    p_name LIKE '%abc%'
```

During analysis, we look at the implied equivalence classes and possible range of values for each equivalence class.
For this view, the following nontrivial equivalence classes are generated:
 * `{l_orderkey, o_orderkey}`
 * `{l_partkey, p_partkey}`

All other columns have their own singleton equivalence classes, which are not shown here.
Likewise, the following nontrivial ranges are generated:
 * `150 <= {l_partkey, p_partkey} < inf`
 * `50 <= {o_custkey} <= 500`

The rest of the equivalence classes are considered to have ranges of `(-inf, inf)`.
The remaining filter `p_name LIKE '%abc%'` is considered 'residual', as it is neither a column equivalence nor a range filter.

Now consider the following query, which we will rewrite to use the view:

```sql
SELECT
    l_orderkey,
    o_custkey,
    l_partkey,
    l_quantity*l_extendedprice
FROM example
WHERE
    l_orderkey = o_orderkey AND
    l_partkey = p_partkey AND
    l_partkey >= 150 AND
    l_partkey <= 160 AND
    o_custkey = 123 AND
    o_orderdate = l_shipdate AND
    p_name LIKE '%abc%' AND
    l_quantity*l_extendedprice > 100
```

This generates the following equivalence classes:
 * `{l_orderkey, o_orderkey}`
 * `{l_partkey, p_partkey}`
 * `{o_orderdate, l_shipdate}`

And the following ranges:
 * `150 <= {l_partkey, p_partkey} <= 160`
 * `123 <= {o_custkey} <= 123`

As before, we still have the residual filter `p_name LIKE '%abc%'`. However, note that `l_quantity*l_extendedprice > 100` is also
a residual filter, as it is not a range filter on a column -- it's a range filter on a mathematical expression.

We perform the three subsumption tests:
 * Equijoin subsumption test:
   * View equivalence classes: `{l_orderkey, o_orderkey}, {l_partkey, p_partkey}`
   * Query equivalence classes: `{l_orderkey, o_orderkey}, {l_partkey, p_partkey}, {o_orderdate, l_shipdate}`
   * Every view equivalence class is a subset of one from the query, so the test passes.
   * We generate the compensating filter `o_orderdate = l_shipdate`.
 * Range subsumption test:
   * View ranges:
     * `150 <= {l_partkey, p_partkey} < inf`
     * `50 <= {o_custkey} <= 500`
   * Query ranges:
     * `150 <= {l_partkey, p_partkey} <= 160`
     * `123 <= {o_custkey} <= 123`
   * Both of the view's ranges contain the corresponding ranges from the query, so the test passes.
   * Since both inclusions are strict, we include both query ranges as compensating filters.
 * Residual subsumption test:
   * View residuals: `p_name LIKE '%abc%'`
   * Query residuals: `p_name LIKE '%abc%'`, `l_quantity*l_extendedprice > 100`
   * Every view residual has a matching residual from the query, so the test passes.
   * The leftover residual in the query, `l_quantity*l_extendedprice > 100`, is included as a compensating filter.

Ultimately we have the following compensating filters:
 * `o_orderdate = l_shipdate`
 * `150 <= {l_partkey, p_partkey} <= 160`
 * `123 <= {o_custkey} <= 123`
 * `l_quantity*l_extendedprice > 100`

The final check is to ensure that the output of our query can be computed from the view, including
any expressions used in the compensating filters.
This is a relatively simple check that mostly involves rewriting expressions to use columns from the view,
and ensuring that no references to the original tables remain.

This example is included as a unit test. After rewriting the query to use the view, the resulting plan looks like this:

```text
+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type     | plan                                                                                                                                                                                                     |
+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan  | Projection: mv.l_orderkey AS l_orderkey, mv.o_custkey AS o_custkey, mv.l_partkey AS l_partkey, mv.gross_revenue AS example.l_quantity * example.l_extendedprice                                          |
|               |   Filter: mv.o_orderdate = mv.l_shipdate AND mv.l_partkey >= Int32(150) AND mv.l_partkey <= Int32(160) AND mv.o_custkey >= Int32(123) AND mv.o_custkey <= Int32(123) AND mv.gross_revenue > Float64(100) |
|               |     TableScan: mv projection=[l_orderkey, o_custkey, l_partkey, l_shipdate, o_orderdate, gross_revenue]                                                                                                  |
| physical_plan | ProjectionExec: expr=[l_orderkey@0 as l_orderkey, o_custkey@1 as o_custkey, l_partkey@2 as l_partkey, gross_revenue@5 as example.l_quantity * example.l_extendedprice]                                    |
|               |   CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                            |
|               |     FilterExec: o_orderdate@4 = l_shipdate@3 AND l_partkey@2 >= 150 AND l_partkey@2 <= 160 AND o_custkey@1 >= 123 AND o_custkey@1 <= 123 AND gross_revenue@5 > 100                                       |
|               |       MemoryExec: partitions=16, partition_sizes=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]                                                                                                        |
|               |                                                                                                                                                                                                          |
+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
```

As one can see, all compensating filters are included, and the query only uses the view.

*/

use std::{
    collections::{BTreeSet, HashMap, HashSet},
    sync::Arc,
};

use datafusion_common::{
    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
    Column, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, TableReference,
};
use datafusion_expr::{
    interval_arithmetic::{satisfy_greater, Interval},
    lit,
    utils::split_conjunction,
    BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, TableScan, TableSource,
};
use itertools::Itertools;

/// A normalized representation of a plan containing only Select/Project/Join in the relational algebra sense.
/// In DataFusion terminology this also includes Filter nodes.
/// Joins are not currently supported, but are planned.
#[derive(Debug, Clone)]
pub struct SpjNormalForm {
    output_schema: Arc<DFSchema>,
    output_exprs: Vec<Expr>,
    referenced_tables: Vec<TableReference>,
    predicate: Predicate,
}

/// Rewrite an expression to re-use output columns from this plan, where possible.
impl TreeNodeRewriter for &SpjNormalForm {
    type Node = Expr;

    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        Ok(match self.output_exprs.iter().position(|x| x == &node) {
            Some(idx) => Transformed::yes(Expr::Column(Column::new_unqualified(
                self.output_schema.field(idx).name().clone(),
            ))),
            None => Transformed::no(node),
        })
    }
}
209
210impl SpjNormalForm {
211 /// Schema of data output by this plan.
212 pub fn output_schema(&self) -> &Arc<DFSchema> {
213 &self.output_schema
214 }
215
216 /// Expressions output by this plan.
217 /// These expressions can be used to rewrite this plan as a cross join followed by a projection;
218 /// however, this does not include any filters in the original plan, so the result will be a superset.
219 pub fn output_exprs(&self) -> &[Expr] {
220 &self.output_exprs
221 }
222
223 /// All tables referenced in this plan.
224 pub fn referenced_tables(&self) -> &[TableReference] {
225 &self.referenced_tables
226 }
227
228 /// Analyze an existing `LogicalPlan` and rewrite it in select-project-join normal form.
229 pub fn new(original_plan: &LogicalPlan) -> Result<Self> {
230 let predicate = Predicate::new(original_plan)?;
231 let output_exprs = get_output_exprs(original_plan)?
232 .into_iter()
233 .map(|expr| predicate.normalize_expr(expr))
234 .collect();
235
236 let mut referenced_tables = vec![];
237 original_plan
238 .apply(|plan| {
239 if let LogicalPlan::TableScan(scan) = plan {
240 referenced_tables.push(scan.table_name.clone());
241 }
242
243 Ok(TreeNodeRecursion::Continue)
244 })
245 // No chance of error since we never return Err -- this unwrap is safe
246 .unwrap();
247
248 Ok(Self {
249 output_schema: Arc::clone(original_plan.schema()),
250 output_exprs,
251 referenced_tables,
252 predicate,
253 })
254 }
255
256 /// Rewrite this plan as as selection/projection on top of another plan,
257 /// which we use `qualifier` to refer to.
258 /// This is useful for rewriting queries to use materialized views.
259 pub fn rewrite_from(
260 &self,
261 mut other: &Self,
262 qualifier: TableReference,
263 source: Arc<dyn TableSource>,
264 ) -> Result<Option<LogicalPlan>> {
265 log::trace!("rewriting from {qualifier}");
266 let mut new_output_exprs = vec![];
267 // check that our output exprs are sub-expressions of the other one's output exprs
268 for (i, output_expr) in self.output_exprs.iter().enumerate() {
269 let new_output_expr = other
270 .predicate
271 .normalize_expr(output_expr.clone())
272 .rewrite(&mut other)?
273 .data;
274
275 // Check that all references to the original tables have been replaced.
276 // All remaining column expressions should be unqualified, which indicates
277 // that they refer to the output of the sub-plan (in this case the view)
278 if new_output_expr
279 .column_refs()
280 .iter()
281 .any(|c| c.relation.is_some())
282 {
283 return Ok(None);
284 }
285
286 let column = &self.output_schema.columns()[i];
287 new_output_exprs.push(
288 new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()),
289 );
290 }
291
292 log::trace!("passed output rewrite");
293
294 // Check the subsumption tests, and compute any auxiliary needed filter expressions.
295 // If we pass all three subsumption tests, this plan's output is a subset of the other
296 // plan's output.
297 let ((eq_filters, range_filters), residual_filters) = match self
298 .predicate
299 .equijoin_subsumption_test(&other.predicate)
300 .zip(self.predicate.range_subsumption_test(&other.predicate)?)
301 .zip(self.predicate.residual_subsumption_test(&other.predicate))
302 {
303 None => return Ok(None),
304 Some(filters) => filters,
305 };
306
307 log::trace!("passed subsumption tests");
308
309 let all_filters = eq_filters
310 .into_iter()
311 .chain(range_filters)
312 .chain(residual_filters)
313 .map(|expr| expr.rewrite(&mut other).unwrap().data)
314 .reduce(|a, b| a.and(b));
315
316 if all_filters
317 .as_ref()
318 .map(|expr| expr.column_refs())
319 .is_some_and(|columns| columns.iter().any(|c| c.relation.is_some()))
320 {
321 return Ok(None);
322 }
323
324 let mut builder = LogicalPlanBuilder::scan(qualifier, source, None)?;
325
326 if let Some(filter) = all_filters {
327 builder = builder.filter(filter)?;
328 }
329
330 builder.project(new_output_exprs)?.build().map(Some)
331 }
332}

/// Stores information on filters from a Select-Project-Join plan.
#[derive(Debug, Clone)]
struct Predicate {
    schema: DFSchema,
    /// List of column equivalence classes.
    eq_classes: Vec<ColumnEquivalenceClass>,
    /// Reverse lookup by equivalence class elements.
    eq_class_idx_by_column: HashMap<Column, usize>,
    /// Stores (possibly empty) intervals describing each equivalence class.
    ranges_by_equivalence_class: Vec<Option<Interval>>,
    /// Filter expressions that aren't column equality predicates or range filters.
    residuals: HashSet<Expr>,
}

impl Predicate {
    fn new(plan: &LogicalPlan) -> Result<Self> {
        let mut schema = DFSchema::empty();
        plan.apply(|plan| {
            if let LogicalPlan::TableScan(scan) = plan {
                schema = schema.join(&scan.projected_schema)?;
            }

            Ok(TreeNodeRecursion::Continue)
        })?;

        let mut new = Self {
            schema,
            eq_classes: vec![],
            eq_class_idx_by_column: HashMap::default(),
            ranges_by_equivalence_class: vec![],
            residuals: HashSet::new(),
        };

        // Collect all referenced columns
        plan.apply(|plan| {
            if let LogicalPlan::TableScan(scan) = plan {
                for (i, column) in scan.projected_schema.columns().iter().enumerate() {
                    new.eq_classes
                        .push(ColumnEquivalenceClass::new_singleton(column.clone()));
                    new.eq_class_idx_by_column.insert(column.clone(), i);
                    new.ranges_by_equivalence_class
                        .push(Some(Interval::make_unbounded(
                            scan.projected_schema.data_type(column)?,
                        )?));
                }
            }

            Ok(TreeNodeRecursion::Continue)
        })?;

        // Collect any filters
        plan.apply(|plan| {
            let filters = match plan {
                LogicalPlan::TableScan(scan) => scan.filters.as_slice(),
                LogicalPlan::Filter(filter) => core::slice::from_ref(&filter.predicate),
                LogicalPlan::Join(_join) => {
                    return Err(DataFusionError::Internal(
                        "joins are not supported yet".to_string(),
                    ))
                }
                LogicalPlan::Projection(_) => &[],
                _ => {
                    return Err(DataFusionError::Plan(format!(
                        "unsupported logical plan: {}",
                        plan.display()
                    )))
                }
            };

            for expr in filters.iter().flat_map(split_conjunction) {
                new.insert_conjunct(expr)?;
            }

            Ok(TreeNodeRecursion::Continue)
        })?;

        Ok(new)
    }

    fn class_for_column(&self, col: &Column) -> Option<&ColumnEquivalenceClass> {
        self.eq_class_idx_by_column
            .get(col)
            .and_then(|&idx| self.eq_classes.get(idx))
    }

    /// Add a new column equivalence.
    fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> {
        match (
            self.eq_class_idx_by_column.get(c1),
            self.eq_class_idx_by_column.get(c2),
        ) {
            (None, None) => {
                // Make a new eq class [c1, c2]
                self.eq_classes
                    .push(ColumnEquivalenceClass::new([c1.clone(), c2.clone()]));
                self.ranges_by_equivalence_class
                    .push(Some(Interval::make_unbounded(
                        self.schema.field_from_column(c1).unwrap().data_type(),
                    )?));
            }

            // These two cases are just adding a column to an existing class
            (None, Some(&idx)) => {
                self.eq_classes[idx].columns.insert(c1.clone());
            }
            (Some(&idx), None) => {
                self.eq_classes[idx].columns.insert(c2.clone());
            }
            (Some(&i), Some(&j)) => {
                // We need to merge two existing column eq classes.

                // Delete the eq class with the larger index,
                // so that the other one has its position preserved.
                // Not strictly necessary, but it's a little simpler this way.
                let (i, j) = if i < j { (i, j) } else { (j, i) };

                // Merge the deleted eq class with its new partner
                let new_columns = self.eq_classes.remove(j).columns;
                self.eq_classes[i].columns.extend(new_columns.clone());
                for column in new_columns {
                    self.eq_class_idx_by_column.insert(column, i);
                }
                // Update all moved entries
                for idx in self.eq_class_idx_by_column.values_mut() {
                    if *idx > j {
                        *idx -= 1;
                    }
                }

                // Merge ranges.
                // Now that we know the two equivalence classes are equal,
                // the new range is the intersection of the existing two ranges.
                self.ranges_by_equivalence_class[i] = self.ranges_by_equivalence_class[i]
                    .clone()
                    .zip(self.ranges_by_equivalence_class.remove(j))
                    .and_then(|(range, other_range)| range.intersect(other_range).transpose())
                    .transpose()?;
            }
        }

        Ok(())
    }

    /// Update the range for a column's equivalence class.
    fn add_range(&mut self, c: &Column, op: &Operator, value: &ScalarValue) -> Result<()> {
        // First coerce the value if needed
        let value = value.cast_to(self.schema.data_type(c)?)?;
        let range = self
            .eq_class_idx_by_column
            .get(c)
            .and_then(|&idx| self.ranges_by_equivalence_class.get_mut(idx))
            .unwrap();
        let new_range = match op {
            Operator::Eq => Interval::try_new(value.clone(), value.clone()),
            Operator::LtEq => {
                Interval::try_new(ScalarValue::try_from(value.data_type())?, value.clone())
            }
            Operator::GtEq => {
                Interval::try_new(value.clone(), ScalarValue::try_from(value.data_type())?)
            }

            // Note: This is a roundabout way (read: hack) to construct an open Interval.
            // DataFusion's Interval type represents closed intervals,
            // so handling of open intervals is done by adding/subtracting the smallest increment.
            // However, there is not really a public API to do this,
            // other than the satisfy_greater method.
            Operator::Lt => Ok(
                match satisfy_greater(
                    &Interval::try_new(value.clone(), value.clone())?,
                    &Interval::make_unbounded(&value.data_type())?,
                    true,
                )? {
                    Some((_, range)) => range,
                    None => {
                        *range = None;
                        return Ok(());
                    }
                },
            ),
            // Same thing as above.
            Operator::Gt => Ok(
                match satisfy_greater(
                    &Interval::make_unbounded(&value.data_type())?,
                    &Interval::try_new(value.clone(), value.clone())?,
                    true,
                )? {
                    Some((range, _)) => range,
                    None => {
                        *range = None;
                        return Ok(());
                    }
                },
            ),
            _ => Err(DataFusionError::Plan(
                "unsupported binary expression".to_string(),
            )),
        }?;

        *range = match range {
            None => Some(new_range),
            Some(range) => range.intersect(new_range)?,
        };

        Ok(())
    }

    /// Add a generic filter expression to our collection of filters.
    /// A conjunct is a term T_i of an expression T_1 AND T_2 AND T_3 AND ...
    fn insert_conjunct(&mut self, expr: &Expr) -> Result<()> {
        match expr {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                self.insert_binary_expr(left, *op, right)?;
            }
            Expr::Not(e) => match e.as_ref() {
                Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                    if let Some(negated) = op.negate() {
                        self.insert_binary_expr(left, negated, right)?;
                    } else {
                        self.residuals.insert(expr.clone());
                    }
                }
                _ => {
                    self.residuals.insert(expr.clone());
                }
            },
            _ => {
                self.residuals.insert(expr.clone());
            }
        }

        Ok(())
    }

    /// Add a binary expression to our collection of filters.
    fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> {
        match (left, op, right) {
            (Expr::Column(c), op, Expr::Literal(v)) => {
                if let Err(e) = self.add_range(c, &op, v) {
                    // Adding a range can fail in some cases, so just fall through
                    log::debug!("failed to add range filter: {e}");
                } else {
                    return Ok(());
                }
            }
            (Expr::Literal(_), op, Expr::Column(_)) => {
                if let Some(swapped) = op.swap() {
                    return self.insert_binary_expr(right, swapped, left);
                }
            }
            // Update eq classes & merge ranges by eq class
            (Expr::Column(c1), Operator::Eq, Expr::Column(c2)) => {
                self.add_equivalence(c1, c2)?;
                return Ok(());
            }
            _ => {}
        }

        self.residuals.insert(Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left.clone()),
            op,
            right: Box::new(right.clone()),
        }));

        Ok(())
    }

    /// Test that all column equivalence classes of `other` are subsumed by one from `self`.
    /// This is called the 'equijoin' subsumption test because column equivalences often
    /// result from join predicates.
    /// Returns any compensating column equality predicates that should be applied to
    /// make this plan match the output of the other one.
    fn equijoin_subsumption_test(&self, other: &Self) -> Option<Vec<Expr>> {
        let mut new_equivalences = vec![];
        // Check that all equivalence classes of `other` are contained in one from `self`
        for other_class in &other.eq_classes {
            let (representative, eq_class) = match other_class
                .columns
                .iter()
                .find_map(|c| self.class_for_column(c).map(|class| (c, class)))
            {
                // We don't contain any columns from this eq class.
                // Technically this is alright if the equivalence class is trivial,
                // because we're allowed to be a subset of the other plan.
                // If the equivalence class is nontrivial then we can't compute
                // compensating filters because we lack the columns that would be
                // used in the filter.
                None if other_class.columns.len() == 1 => continue,
                // We do contain columns from this eq class.
                Some(tuple) => tuple,
                // We don't contain columns from this eq class and the
                // class is nontrivial.
                _ => return None,
            };

            if !other_class.columns.is_subset(&eq_class.columns) {
                return None;
            }

            for column in eq_class.columns.difference(&other_class.columns) {
                new_equivalences
                    .push(Expr::Column(representative.clone()).eq(Expr::Column(column.clone())))
            }
        }

        log::trace!("passed equijoin subsumption test");

        Some(new_equivalences)
    }

    /// Test that all range filters of `self` are contained in one from `other`.
    /// This includes equality comparisons, which map to ranges of the form [v, v]
    /// for some value v.
    /// Returns any compensating range filters that should be applied to this plan
    /// to make its output match the other one.
    fn range_subsumption_test(&self, other: &Self) -> Result<Option<Vec<Expr>>> {
        let mut extra_range_filters = vec![];
        for (eq_class, range) in self
            .eq_classes
            .iter()
            .zip(self.ranges_by_equivalence_class.iter())
        {
            let range = match range {
                None => {
                    // Empty, so it's vacuously contained in any other range.
                    // An empty range is also never satisfiable, so it's always false.
                    extra_range_filters.push(lit(false));
                    continue;
                }
                Some(range) => range,
            };

            let (other_column, other_range) = match eq_class.columns.iter().find_map(|c| {
                other.eq_class_idx_by_column.get(c).and_then(|&idx| {
                    other.ranges_by_equivalence_class[idx]
                        .as_ref()
                        .map(|range| (other.eq_classes[idx].columns.first().unwrap(), range))
                })
            }) {
                None => return Ok(None),
                Some(range) => range,
            };

            if other_range.contains(range)? != Interval::CERTAINLY_TRUE {
                return Ok(None);
            }

            if range.contains(other_range)? != Interval::CERTAINLY_TRUE {
                if !(range.lower().is_null() || range.upper().is_null())
                    && (range.lower().eq(range.upper()))
                {
                    // Certain DataFusion code paths only work if eq expressions are preserved;
                    // that is, col >= val AND col <= val is not treated the same as col = val.
                    // We special-case this to make sure everything works as expected.
                    // TODO: could this be a logical optimizer rule?
                    extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
                        left: Box::new(Expr::Column(other_column.clone())),
                        op: Operator::Eq,
                        right: Box::new(Expr::Literal(range.lower().clone())),
                    }))
                } else {
                    if !range.lower().is_null() {
                        extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
                            left: Box::new(Expr::Column(other_column.clone())),
                            op: Operator::GtEq,
                            right: Box::new(Expr::Literal(range.lower().clone())),
                        }))
                    }

                    if !range.upper().is_null() {
                        extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
                            left: Box::new(Expr::Column(other_column.clone())),
                            op: Operator::LtEq,
                            right: Box::new(Expr::Literal(range.upper().clone())),
                        }))
                    }
                }
            }
        }

        log::trace!("passed range subsumption test");

        Ok(Some(extra_range_filters))
    }

    /// Test that any "residual" filters (not column equivalence or range filters) from
    /// `other` have matching entries in `self`.
    /// For example, a residual filter might look like `x * y > 100`, as this expression
    /// is neither a column equivalence nor a range filter (importantly, not a range filter
    /// directly on a column).
    /// This ensures that `self` is a subset of `other`.
    /// Returns any residual filters in this plan that are not in the other one.
    fn residual_subsumption_test(&self, other: &Self) -> Option<Vec<Expr>> {
        let [self_residuals, other_residuals] = [self.residuals.clone(), other.residuals.clone()]
            .map(|set| {
                set.into_iter()
                    .map(|r| self.normalize_expr(r))
                    .collect::<HashSet<Expr>>()
            });

        if !self_residuals.is_superset(&other_residuals) {
            return None;
        }

        log::trace!("passed residual subsumption test");

        Some(
            self_residuals
                .difference(&other_residuals)
                .cloned()
                .collect_vec(),
        )
    }

    /// Rewrite all expressions in terms of their normal representatives
    /// with respect to this predicate's equivalence classes.
    fn normalize_expr(&self, e: Expr) -> Expr {
        e.transform(&|e| {
            let c = match e {
                Expr::Column(c) => c,
                Expr::Alias(alias) => return Ok(Transformed::yes(alias.expr.as_ref().clone())),
                _ => return Ok(Transformed::no(e)),
            };

            if let Some(eq_class) = self.class_for_column(&c) {
                Ok(Transformed::yes(Expr::Column(
                    eq_class.columns.first().unwrap().clone(),
                )))
            } else {
                Ok(Transformed::no(Expr::Column(c)))
            }
        })
        .map(|t| t.data)
        // No chance of error since we never return Err -- this unwrap is safe
        .unwrap()
    }
}

/// A collection of columns that are all considered to be equivalent.
/// In some cases we normalize expressions so that they use the "normal" representative
/// in place of any other columns in the class.
/// This normal representative is chosen arbitrarily.
#[derive(Debug, Clone, Default)]
struct ColumnEquivalenceClass {
    // The first element is the normal representative of the equivalence class
    columns: BTreeSet<Column>,
}

impl ColumnEquivalenceClass {
    fn new(columns: impl IntoIterator<Item = Column>) -> Self {
        Self {
            columns: BTreeSet::from_iter(columns),
        }
    }

    fn new_singleton(column: Column) -> Self {
        Self {
            columns: BTreeSet::from([column]),
        }
    }
}

/// For each field in the plan's schema, get an expression that represents the field's definition.
/// Furthermore, normalize all expressions so that the only column expressions are direct references
/// to tables, not aliased subqueries or child plans.
///
/// This is essentially equivalent to rewriting the query as a projection against a cross join.
fn get_output_exprs(plan: &LogicalPlan) -> Result<Vec<Expr>> {
    use datafusion_expr::logical_plan::*;

    let output_exprs = match plan {
        // Ignore filter, sort, limit, and distinct;
        // they don't change the schema or the definitions
        LogicalPlan::Filter(_)
        | LogicalPlan::Sort(_)
        | LogicalPlan::Limit(_)
        | LogicalPlan::Distinct(_) => return get_output_exprs(plan.inputs()[0]),
        LogicalPlan::Projection(Projection { expr, .. }) => Ok(expr.clone()),
        LogicalPlan::Aggregate(Aggregate {
            group_expr,
            aggr_expr,
            ..
        }) => Ok(Vec::from_iter(
            group_expr.iter().chain(aggr_expr.iter()).cloned(),
        )),
        LogicalPlan::Window(Window {
            input, window_expr, ..
        }) => Ok(Vec::from_iter(
            input
                .schema()
                .fields()
                .iter()
                .map(|field| Expr::Column(Column::new_unqualified(field.name())))
                .chain(window_expr.iter().cloned()),
        )),
        // If it's a table scan, just exit early with an explicit return
        LogicalPlan::TableScan(table_scan) => {
            return Ok(get_table_scan_columns(table_scan)?
                .into_iter()
                .map(Expr::Column)
                .collect())
        }
        LogicalPlan::Unnest(unnest) => Ok(unnest
            .schema
            .columns()
            .into_iter()
            .map(Expr::Column)
            .collect()),
        LogicalPlan::Join(join) => Ok(join
            .left
            .schema()
            .columns()
            .into_iter()
            .chain(join.right.schema().columns())
            .map(Expr::Column)
            .collect_vec()),
        LogicalPlan::SubqueryAlias(sa) => return get_output_exprs(&sa.input),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Logical plan not supported: {}",
            plan.display()
        ))),
    }?;

    flatten_exprs(output_exprs, plan)
}
858
859/// Recursively normalize expressions so that any columns refer directly to tables and not subqueries.
860fn flatten_exprs(exprs: Vec<Expr>, parent: &LogicalPlan) -> Result<Vec<Expr>> {
861 if matches!(parent, LogicalPlan::TableScan(_)) {
862 return Ok(exprs);
863 }
864
865 let schemas = parent
866 .inputs()
867 .iter()
868 .map(|input| input.schema().as_ref())
869 .collect_vec();
870 let using_columns = parent.using_columns()?;
871
872 let output_exprs_by_child = parent
873 .inputs()
874 .into_iter()
875 .map(get_output_exprs)
876 .collect::<Result<Vec<_>>>()?;
877
878 exprs
879 .into_iter()
880 .map(|expr| {
881 expr.transform_up(&|e| match e {
882 // if the relation is None, it's a column referencing one of the child plans
883 // if the relation is Some, it's a column of a table (most likely) and can be ignored since it's a leaf node
884 // (technically it can also refer to an aliased subquery)
885 Expr::Column(col) => {
886 // Figure out which child the column belongs to
887 let col = {
888 let col = if let LogicalPlan::SubqueryAlias(sa) = parent {
889 // If the parent is an aliased subquery, with the alias 'x',
890 // any expressions of the form `x.column1`
891 // refer to `column` in the input
                            if col.relation.as_ref() == Some(&sa.alias) {
                                Column::new_unqualified(col.name)
                            } else {
                                // All other columns are assumed to be leaf nodes
                                // (direct references to tables).
                                return Ok(Transformed::no(Expr::Column(col)));
                            }
                        } else {
                            col
                        };

                        col.normalize_with_schemas_and_ambiguity_check(&[&schemas], &using_columns)?
                    };

                    // Find the first schema that contains the column. The ambiguity check
                    // above guarantees that a match exists and is unique (except for USING
                    // columns, where any of the matches is equivalent).
                    let (child_idx, expr_idx) = schemas
                        .iter()
                        .enumerate()
                        .find_map(|(schema_idx, schema)| {
                            Some(schema_idx).zip(schema.maybe_index_of_column(&col))
                        })
                        .unwrap();

                    Ok(Transformed::yes(
                        output_exprs_by_child[child_idx][expr_idx].clone(),
                    ))
                }
                _ => Ok(Transformed::no(e)),
            })
            .data()
        })
        .collect()
}

/// Return the columns output by this [`TableScan`].
fn get_table_scan_columns(scan: &TableScan) -> Result<Vec<Column>> {
    let fields = {
        let mut schema = scan.source.schema().as_ref().clone();
        if let Some(ref p) = scan.projection {
            schema = schema.project(p)?;
        }
        schema.fields
    };

    Ok(fields
        .into_iter()
        .map(|field| Column::new(Some(scan.table_name.to_owned()), field.name()))
        .collect())
}

#[cfg(test)]
mod test {
    use arrow::compute::concat_batches;
    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
    use datafusion_common::{DataFusionError, Result};
    use datafusion_sql::TableReference;

    use super::SpjNormalForm;

    async fn setup() -> Result<SessionContext> {
        let ctx = SessionContext::new();

        ctx.sql(
            "CREATE TABLE t1 AS VALUES
                ('2021', 3, 'A'),
                ('2022', 4, 'B'),
                ('2023', 5, 'C')",
        )
        .await
        .map_err(|e| e.context("parse `t1` table ddl"))?
        .collect()
        .await?;

        ctx.sql(
            "CREATE TABLE example (
                l_orderkey INT,
                l_partkey INT,
                l_shipdate DATE,
                l_quantity DOUBLE,
                l_extendedprice DOUBLE,
                o_custkey INT,
                o_orderkey INT,
                o_orderdate DATE,
                p_name VARCHAR,
                p_partkey INT
            )",
        )
        .await
        .map_err(|e| e.context("parse `example` table ddl"))?
        .collect()
        .await?;

        Ok(ctx)
    }

    struct TestCase {
        name: &'static str,
        base: &'static str,
        query: &'static str,
    }

    async fn run_test(case: &TestCase) -> Result<()> {
        let context = setup()
            .await
            .map_err(|e| e.context("setup test environment"))?;

        // Optimize the plan to eliminate unnormalized wildcard exprs
        let base_plan = context.sql(case.base).await?.into_optimized_plan()?;
        let base_normal_form = SpjNormalForm::new(&base_plan)?;

        context
            .sql(&format!("CREATE TABLE mv AS {}", case.base))
            .await?
            .collect()
            .await?;

        let query_plan = context.sql(case.query).await?.into_optimized_plan()?;
        let query_normal_form = SpjNormalForm::new(&query_plan)?;

        let table_ref = TableReference::bare("mv");
        let rewritten = query_normal_form
            .rewrite_from(
                &base_normal_form,
                table_ref.clone(),
                provider_as_source(context.table_provider(table_ref).await?),
            )?
            .ok_or(DataFusionError::Internal(
                "expected rewrite to succeed".to_string(),
            ))?;

        assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());

        for plan in [&base_plan, &query_plan, &rewritten] {
            context
                .execute_logical_plan(plan.clone())
                .await?
                .explain(false, false)?
                .show()
                .await?;
        }

        let expected = concat_batches(
            &query_plan.schema().as_ref().clone().into(),
            &context
                .execute_logical_plan(query_plan)
                .await?
                .collect()
                .await?,
        )?;

        let result = concat_batches(
            &rewritten.schema().as_ref().clone().into(),
            &context
                .execute_logical_plan(rewritten)
                .await?
                .collect()
                .await?,
        )?;

        assert_eq!(result, expected);

        Ok(())
    }

    #[tokio::test]
    async fn test_rewrite() -> Result<()> {
        let _ = env_logger::builder().is_test(true).try_init();
        let cases = vec![
            TestCase {
                name: "simple selection",
                base: "SELECT * FROM t1",
                query: "SELECT column1, column2 FROM t1",
            },
            TestCase {
                name: "selection with equality predicate",
                base: "SELECT * FROM t1",
                query: "SELECT column1, column2 FROM t1 WHERE column1 = column3",
            },
            TestCase {
                name: "selection with range filter",
                base: "SELECT * FROM t1 WHERE column2 > 3",
                query: "SELECT column1, column2 FROM t1 WHERE column2 > 4",
            },
            TestCase {
                name: "nontrivial projection",
                base: "SELECT concat(column1, column2), column2 FROM t1",
                query: "SELECT concat(column1, column2) FROM t1",
            },
            TestCase {
                name: "range filter + equality predicate",
                base:
                    "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'",
                // Since column1 = column3 in the original view,
                // we are allowed to substitute column1 for column3 and vice versa.
                query:
                    "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'",
            },
            TestCase {
                name: "duplicate expressions (X-209)",
                base: "SELECT * FROM t1",
                query:
                    "SELECT column1, NULL AS column2, NULL AS column3, column3 AS column4 FROM t1",
            },
            TestCase {
                name: "example from paper",
                base: "\
                    SELECT
                        l_orderkey,
                        o_custkey,
                        l_partkey,
                        l_shipdate,
                        o_orderdate,
                        l_quantity*l_extendedprice AS gross_revenue
                    FROM example
                    WHERE
                        l_orderkey = o_orderkey AND
                        l_partkey = p_partkey AND
                        p_partkey >= 150 AND
                        o_custkey >= 50 AND
                        o_custkey <= 500 AND
                        p_name LIKE '%abc%'",
                query: "SELECT
                        l_orderkey,
                        o_custkey,
                        l_partkey,
                        l_quantity*l_extendedprice
                    FROM example
                    WHERE
                        l_orderkey = o_orderkey AND
                        l_partkey = p_partkey AND
                        l_partkey >= 150 AND
                        l_partkey <= 160 AND
                        o_custkey = 123 AND
                        o_orderdate = l_shipdate AND
                        p_name LIKE '%abc%' AND
                        l_quantity*l_extendedprice > 100",
            },
        ];

        for case in cases {
            println!("executing test: {}", case.name);
            run_test(&case).await.map_err(|e| e.context(case.name))?;
        }

        Ok(())
    }
}