datafusion_materialized_views/rewrite/
normal_form.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18/*!
19
20This module contains code primarily used for view matching. We implement the view matching algorithm from [this paper](https://courses.cs.washington.edu/courses/cse591d/01sp/opt_views.pdf),
21which provides a method for determining when one Select-Project-Join query can be rewritten in terms of another Select-Project-Join query.
22
23The implementation is contained in [`SpjNormalForm::rewrite_from`]. The method can be summarized as follows:
241. Compute column equivalence classes for the query and the view.
252. Compute range intervals for the query and the view.
263. (Equijoin subsumption test) Check that each column equivalence class of the view is a subset of a column equivalence class of the query.
274. (Range subsumption test) Check that each range of the view contains the corresponding range from the query.
285. (Residual subsumption test) Check that every filter in the view that is not a column equivalence relation or a range filter matches a filter from the query.
296. Compute any compensating filters needed in order to restrict the view's rows to match the query.
307. Check that the output of the query, and the compensating filters, can be rewritten using the view's columns as inputs.
31
32# Example
33
34Consider the following table:
35
36```sql
37CREATE TABLE example (
38    l_orderkey INT,
39    l_partkey INT,
40    l_shipdate DATE,
41    l_quantity DOUBLE,
42    l_extendedprice DOUBLE,
43    o_custkey INT,
44    o_orderkey INT,
45    o_orderdate DATE,
46    p_name VARCHAR,
47    p_partkey INT,
48)
49```
50
51And consider the follow view:
52
53```sql
54CREATE VIEW mv AS SELECT
55    l_orderkey,
56    o_custkey,
57    l_partkey,
58    l_shipdate, o_orderdate,
59    l_quantity*l_extendedprice AS gross_revenue
60FROM example
61WHERE
62    l_orderkey = o_orderkey AND
63    l_partkey = p_partkey AND
64    p_partkey >= 150 AND
65    o_custkey >= 50 AND
66    o_custkey <= 500 AND
67    p_name LIKE '%abc%'
68```
69
70During analysis, we look at the implied equivalence classes and possible range of values for each equivalence class.
71For this view, the following nontrivial equivalence classes are generated:
72 * `{l_orderkey, o_orderkey}`
73 * `{l_partkey, p_partkey}`
74
75Note that all other columns have their own singleton equivalence classes, but are not shown here.
76Likewise, the following nontrivial ranges are generated:
77 * `150 <= {l_partkey, p_partkey} < inf`
78 * `50 <= {o_custkey} <= 500`
79
80The rest of the equivalence classes are considered to have ranges of (-inf, inf).
81The remaining filter `p_name LIKE '%abc%'` is considered 'residual' as it is not a column equivalence nor a range filter.
82
83Now consider the following query, which we will rewrite to use the view:
84
85```sql
86SELECT
87    l_orderkey,
88    o_custkey,
89    l_partkey,
90    l_quantity*l_extendedprice
91FROM example
92WHERE
93    l_orderkey = o_orderkey AND
94    l_partkey = p_partkey AND
95    l_partkey >= 150 AND
96    l_partkey <= 160 AND
97    o_custkey = 123 AND
98    o_orderdate = l_shipdate AND
99    p_name like '%abc%' AND
100    l_quantity*l_extendedprice > 100
101````
102
103This generates the following equivalence classes:
104 * `{l_orderkey, o_orderkey}`
105 * `{l_partkey, p_partkey}`
106 * `{o_orderdate, l_shipdate}`
107
108And the following ranges:
109 * `150 <= {l_partkey, p_partkey} <= 160`
110 * `123 <= {o_custkey} <= 123`
111
112As before, we still have the residual filter `p_name LIKE '%abc'`. However, note that `l_quantity*l_extendedprice > 100` is also
113a residual filter, as it is not a range filter on a column -- it's a range filter on a mathematical expression.
114
115We perform the three subsumption tests:
116 * Equijoin subsumption test:
117   * View equivalence classes: `{l_orderkey, o_orderkey}, {l_partkey, p_partkey}`
118   * Query equivalence classes: `{l_orderkey, o_orderkey}, {l_partkey, p_partkey}, {o_orderdate, l_shipdate}`
119   * Every view equivalence class is a subset of one from the query, so the test is passed.
120   * We generate the compensating filter `o_orderdate = l_shipdate`.
121 * Range subsumption test:
122   * View ranges:
123     * `150 <= {l_partkey, p_partkey} < inf`
124     * `50 <= {o_custkey} <= 500`
125   * Query ranges:
126     * `150 <= {l_partkey, p_partkey} <= 160`
127     * `123 <= {o_custkey} <= 123`
128   * Both of the view's ranges contain corresponding ranges from the query, therefore the test is passed.
129   * Since they're both strict inclusions, we include them both as compensating filters.
130 * Residual subsumption test:
131   * View residuals: `p_name LIKE '%abc'`
132   * Query residuals: `p_name LIKE '%abc'`, `l_quantity*l_extendedprice > 100`
133   * Every view residual has a matching residual from the query, and the test is passed.
134   * The leftover residual in the query, `l_quantity*l_extendedprice > 100`, is included as a compensating filter.
135
136Ultimately we have the following compensating filters:
137 * `o_orderdate = l_shipdate`
138 * `150 <= {l_partkey, p_partkey} <= 160`
139 * `123 <= {o_custkey} <= 123`
140 * `l_quantity*l_extendedprice > 100`
141
142The final check is to ensure that the output of our query can be computed from the view. This includes
143any expressions used in the compensating filters.
144This is a relatively simple check that mostly involves rewriting expressions to use columns from the view,
145and ensuring that no references to the original tables are left.
146
147This example is included as a unit test. After rewriting the query to use the view, the resulting plan looks like this:
148
149```text
150+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
151| plan_type     | plan                                                                                                                                                                                                     |
152+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
153| logical_plan  | Projection: mv.l_orderkey AS l_orderkey, mv.o_custkey AS o_custkey, mv.l_partkey AS l_partkey, mv.gross_revenue AS example.l_quantity * example.l_extendedprice                                          |
154|               |   Filter: mv.o_orderdate = mv.l_shipdate AND mv.l_partkey >= Int32(150) AND mv.l_partkey <= Int32(160) AND mv.o_custkey >= Int32(123) AND mv.o_custkey <= Int32(123) AND mv.gross_revenue > Float64(100) |
155|               |     TableScan: mv projection=[l_orderkey, o_custkey, l_partkey, l_shipdate, o_orderdate, gross_revenue]                                                                                                  |
156| physical_plan | ProjectionExec: expr=[l_orderkey@0 as l_orderkey, o_custkey@1 as o_custkey, l_partkey@2 as l_partkey, gross_revenue@5 as example.l_quantity * example.l_extendedprice]                                   |
157|               |   CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                            |
158|               |     FilterExec: o_orderdate@4 = l_shipdate@3 AND l_partkey@2 >= 150 AND l_partkey@2 <= 160 AND o_custkey@1 >= 123 AND o_custkey@1 <= 123 AND gross_revenue@5 > 100                                       |
159|               |       MemoryExec: partitions=16, partition_sizes=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]                                                                                                        |
160|               |                                                                                                                                                                                                          |
161+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
162```
163
164As one can see, all compensating filters are included, and the query only uses the view.
165
166*/
167
168use std::{
169    collections::{BTreeSet, HashMap, HashSet},
170    sync::Arc,
171};
172
173use datafusion_common::{
174    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
175    Column, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, TableReference,
176};
177use datafusion_expr::{
178    interval_arithmetic::{satisfy_greater, Interval},
179    lit,
180    utils::split_conjunction,
181    BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, TableScan, TableSource,
182};
183use itertools::Itertools;
184
185/// A normalized representation of a plan containing only Select/Project/Join in the relational algebra sense.
186/// In DataFusion terminology this also includes Filter nodes.
187/// Joins are not currently supported, but are planned.
188#[derive(Debug, Clone)]
189pub struct SpjNormalForm {
190    output_schema: Arc<DFSchema>,
191    output_exprs: Vec<Expr>,
192    referenced_tables: Vec<TableReference>,
193    predicate: Predicate,
194}
195
196/// Rewrite an expression to re-use output columns from this plan, where possible.
197impl TreeNodeRewriter for &SpjNormalForm {
198    type Node = Expr;
199
200    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
201        Ok(match self.output_exprs.iter().position(|x| x == &node) {
202            Some(idx) => Transformed::yes(Expr::Column(Column::new_unqualified(
203                self.output_schema.field(idx).name().clone(),
204            ))),
205            None => Transformed::no(node),
206        })
207    }
208}
209
210impl SpjNormalForm {
211    /// Schema of data output by this plan.
212    pub fn output_schema(&self) -> &Arc<DFSchema> {
213        &self.output_schema
214    }
215
216    /// Expressions output by this plan.
217    /// These expressions can be used to rewrite this plan as a cross join followed by a projection;
218    /// however, this does not include any filters in the original plan, so the result will be a superset.
219    pub fn output_exprs(&self) -> &[Expr] {
220        &self.output_exprs
221    }
222
223    /// All tables referenced in this plan.
224    pub fn referenced_tables(&self) -> &[TableReference] {
225        &self.referenced_tables
226    }
227
228    /// Analyze an existing `LogicalPlan` and rewrite it in select-project-join normal form.
229    pub fn new(original_plan: &LogicalPlan) -> Result<Self> {
230        let predicate = Predicate::new(original_plan)?;
231        let output_exprs = get_output_exprs(original_plan)?
232            .into_iter()
233            .map(|expr| predicate.normalize_expr(expr))
234            .collect();
235
236        let mut referenced_tables = vec![];
237        original_plan
238            .apply(|plan| {
239                if let LogicalPlan::TableScan(scan) = plan {
240                    referenced_tables.push(scan.table_name.clone());
241                }
242
243                Ok(TreeNodeRecursion::Continue)
244            })
245            // No chance of error since we never return Err -- this unwrap is safe
246            .unwrap();
247
248        Ok(Self {
249            output_schema: Arc::clone(original_plan.schema()),
250            output_exprs,
251            referenced_tables,
252            predicate,
253        })
254    }
255
256    /// Rewrite this plan as as selection/projection on top of another plan,
257    /// which we use `qualifier` to refer to.
258    /// This is useful for rewriting queries to use materialized views.
259    pub fn rewrite_from(
260        &self,
261        mut other: &Self,
262        qualifier: TableReference,
263        source: Arc<dyn TableSource>,
264    ) -> Result<Option<LogicalPlan>> {
265        log::trace!("rewriting from {qualifier}");
266        let mut new_output_exprs = vec![];
267        // check that our output exprs are sub-expressions of the other one's output exprs
268        for (i, output_expr) in self.output_exprs.iter().enumerate() {
269            let new_output_expr = other
270                .predicate
271                .normalize_expr(output_expr.clone())
272                .rewrite(&mut other)?
273                .data;
274
275            // Check that all references to the original tables have been replaced.
276            // All remaining column expressions should be unqualified, which indicates
277            // that they refer to the output of the sub-plan (in this case the view)
278            if new_output_expr
279                .column_refs()
280                .iter()
281                .any(|c| c.relation.is_some())
282            {
283                return Ok(None);
284            }
285
286            let column = &self.output_schema.columns()[i];
287            new_output_exprs.push(
288                new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()),
289            );
290        }
291
292        log::trace!("passed output rewrite");
293
294        // Check the subsumption tests, and compute any auxiliary needed filter expressions.
295        // If we pass all three subsumption tests, this plan's output is a subset of the other
296        // plan's output.
297        let ((eq_filters, range_filters), residual_filters) = match self
298            .predicate
299            .equijoin_subsumption_test(&other.predicate)
300            .zip(self.predicate.range_subsumption_test(&other.predicate)?)
301            .zip(self.predicate.residual_subsumption_test(&other.predicate))
302        {
303            None => return Ok(None),
304            Some(filters) => filters,
305        };
306
307        log::trace!("passed subsumption tests");
308
309        let all_filters = eq_filters
310            .into_iter()
311            .chain(range_filters)
312            .chain(residual_filters)
313            .map(|expr| expr.rewrite(&mut other).unwrap().data)
314            .reduce(|a, b| a.and(b));
315
316        if all_filters
317            .as_ref()
318            .map(|expr| expr.column_refs())
319            .is_some_and(|columns| columns.iter().any(|c| c.relation.is_some()))
320        {
321            return Ok(None);
322        }
323
324        let mut builder = LogicalPlanBuilder::scan(qualifier, source, None)?;
325
326        if let Some(filter) = all_filters {
327            builder = builder.filter(filter)?;
328        }
329
330        builder.project(new_output_exprs)?.build().map(Some)
331    }
332}
333
334/// Stores information on filters from a Select-Project-Join plan.
335#[derive(Debug, Clone)]
336struct Predicate {
337    schema: DFSchema,
338    /// List of column equivalence classes.
339    eq_classes: Vec<ColumnEquivalenceClass>,
340    /// Reverse lookup by eq class elements
341    eq_class_idx_by_column: HashMap<Column, usize>,
342    /// Stores (possibly empty) intervals describing each equivalence class.
343    ranges_by_equivalence_class: Vec<Option<Interval>>,
344    /// Filter expressions that aren't column equality predicates or range filters.
345    residuals: HashSet<Expr>,
346}
347
348impl Predicate {
349    fn new(plan: &LogicalPlan) -> Result<Self> {
350        let mut schema = DFSchema::empty();
351        plan.apply(|plan| {
352            if let LogicalPlan::TableScan(scan) = plan {
353                schema = schema.join(&scan.projected_schema)?;
354            }
355
356            Ok(TreeNodeRecursion::Continue)
357        })?;
358
359        let mut new = Self {
360            schema,
361            eq_classes: vec![],
362            eq_class_idx_by_column: HashMap::default(),
363            ranges_by_equivalence_class: vec![],
364            residuals: HashSet::new(),
365        };
366
367        // Collect all referenced columns
368        plan.apply(|plan| {
369            if let LogicalPlan::TableScan(scan) = plan {
370                for (i, column) in scan.projected_schema.columns().iter().enumerate() {
371                    new.eq_classes
372                        .push(ColumnEquivalenceClass::new_singleton(column.clone()));
373                    new.eq_class_idx_by_column.insert(column.clone(), i);
374                    new.ranges_by_equivalence_class
375                        .push(Some(Interval::make_unbounded(
376                            scan.projected_schema.data_type(column)?,
377                        )?));
378                }
379            }
380
381            Ok(TreeNodeRecursion::Continue)
382        })?;
383
384        // Collect any filters
385        plan.apply(|plan| {
386            let filters = match plan {
387                LogicalPlan::TableScan(scan) => scan.filters.as_slice(),
388                LogicalPlan::Filter(filter) => core::slice::from_ref(&filter.predicate),
389                LogicalPlan::Join(_join) => {
390                    return Err(DataFusionError::Internal(
391                        "joins are not supported yet".to_string(),
392                    ))
393                }
394                LogicalPlan::Projection(_) => &[],
395                _ => {
396                    return Err(DataFusionError::Plan(format!(
397                        "unsupported logical plan: {}",
398                        plan.display()
399                    )))
400                }
401            };
402
403            for expr in filters.iter().flat_map(split_conjunction) {
404                new.insert_conjuct(expr)?;
405            }
406
407            Ok(TreeNodeRecursion::Continue)
408        })?;
409
410        Ok(new)
411    }
412
413    fn class_for_column(&self, col: &Column) -> Option<&ColumnEquivalenceClass> {
414        self.eq_class_idx_by_column
415            .get(col)
416            .and_then(|&idx| self.eq_classes.get(idx))
417    }
418
419    /// Add a new column equivalence
420    fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> {
421        match (
422            self.eq_class_idx_by_column.get(c1),
423            self.eq_class_idx_by_column.get(c2),
424        ) {
425            (None, None) => {
426                // Make a new eq class [c1, c2]
427                self.eq_classes
428                    .push(ColumnEquivalenceClass::new([c1.clone(), c2.clone()]));
429                self.ranges_by_equivalence_class
430                    .push(Some(Interval::make_unbounded(
431                        self.schema.field_from_column(c1).unwrap().data_type(),
432                    )?));
433            }
434
435            // These two cases are just adding a column to an existing class
436            (None, Some(&idx)) => {
437                self.eq_classes[idx].columns.insert(c1.clone());
438            }
439            (Some(&idx), None) => {
440                self.eq_classes[idx].columns.insert(c2.clone());
441            }
442            (Some(&i), Some(&j)) => {
443                // We need to merge two existing column eq classes.
444
445                // Delete the eq class with a larger index,
446                // so that the other one has its position preserved.
447                // Not necessary, but it's just a little simpler this way
448                let (i, j) = if i < j { (i, j) } else { (j, i) };
449
450                // Merge the deleted eq class with its new partner
451                let new_columns = self.eq_classes.remove(j).columns;
452                self.eq_classes[i].columns.extend(new_columns.clone());
453                for column in new_columns {
454                    self.eq_class_idx_by_column.insert(column, i);
455                }
456                // update all moved entries
457                for idx in self.eq_class_idx_by_column.values_mut() {
458                    if *idx > j {
459                        *idx -= 1;
460                    }
461                }
462
463                // Merge ranges
464                // Now that we know the two equivalence classes are equal,
465                // the new range is the intersection of the existing two ranges.
466                self.ranges_by_equivalence_class[i] = self.ranges_by_equivalence_class[i]
467                    .clone()
468                    .zip(self.ranges_by_equivalence_class.remove(j))
469                    .and_then(|(range, other_range)| range.intersect(other_range).transpose())
470                    .transpose()?;
471            }
472        }
473
474        Ok(())
475    }
476
477    /// Update range for a column's equivalence class
478    fn add_range(&mut self, c: &Column, op: &Operator, value: &ScalarValue) -> Result<()> {
479        // first coerce the value if needed
480        let value = value.cast_to(self.schema.data_type(c)?)?;
481        let range = self
482            .eq_class_idx_by_column
483            .get(c)
484            .and_then(|&idx| self.ranges_by_equivalence_class.get_mut(idx))
485            .unwrap();
486        let new_range = match op {
487            Operator::Eq => Interval::try_new(value.clone(), value.clone()),
488            Operator::LtEq => {
489                Interval::try_new(ScalarValue::try_from(value.data_type())?, value.clone())
490            }
491            Operator::GtEq => {
492                Interval::try_new(value.clone(), ScalarValue::try_from(value.data_type())?)
493            }
494
495            // Note: This is a roundabout way (read: hack) to construct an open Interval.
496            // DataFusion's Interval type represents closed intervals,
497            // so handling of open intervals is done by adding/subtracting the smallest increment.
498            // However, there is not really a public API to do this,
499            // other than the satisfy_greater method.
500            Operator::Lt => Ok(
501                match satisfy_greater(
502                    &Interval::try_new(value.clone(), value.clone())?,
503                    &Interval::make_unbounded(&value.data_type())?,
504                    true,
505                )? {
506                    Some((_, range)) => range,
507                    None => {
508                        *range = None;
509                        return Ok(());
510                    }
511                },
512            ),
513            // Same thing as above.
514            Operator::Gt => Ok(
515                match satisfy_greater(
516                    &Interval::make_unbounded(&value.data_type())?,
517                    &Interval::try_new(value.clone(), value.clone())?,
518                    true,
519                )? {
520                    Some((range, _)) => range,
521                    None => {
522                        *range = None;
523                        return Ok(());
524                    }
525                },
526            ),
527            _ => Err(DataFusionError::Plan(
528                "unsupported binary expression".to_string(),
529            )),
530        }?;
531
532        *range = match range {
533            None => Some(new_range),
534            Some(range) => range.intersect(new_range)?,
535        };
536
537        Ok(())
538    }
539
540    /// Add a generic filter expression to our collection of filters.
541    /// A conjunct is a term T_i of an expression T_1 AND T_2 AND T_3 AND ...
542    fn insert_conjuct(&mut self, expr: &Expr) -> Result<()> {
543        match expr {
544            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
545                self.insert_binary_expr(left, *op, right)?;
546            }
547            Expr::Not(e) => match e.as_ref() {
548                Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
549                    if let Some(negated) = op.negate() {
550                        self.insert_binary_expr(left, negated, right)?;
551                    } else {
552                        self.residuals.insert(expr.clone());
553                    }
554                }
555                _ => {
556                    self.residuals.insert(expr.clone());
557                }
558            },
559            _ => {
560                self.residuals.insert(expr.clone());
561            }
562        }
563
564        Ok(())
565    }
566
567    /// Add a binary expression to our collection of filters.
568    fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> {
569        match (left, op, right) {
570            (Expr::Column(c), op, Expr::Literal(v)) => {
571                if let Err(e) = self.add_range(c, &op, v) {
572                    // Add a range can fail in some cases, so just fallthrough
573                    log::debug!("failed to add range filter: {e}");
574                } else {
575                    return Ok(());
576                }
577            }
578            (Expr::Literal(_), op, Expr::Column(_)) => {
579                if let Some(swapped) = op.swap() {
580                    return self.insert_binary_expr(right, swapped, left);
581                }
582            }
583            // update eq classes & merge ranges by eq class
584            (Expr::Column(c1), Operator::Eq, Expr::Column(c2)) => {
585                self.add_equivalence(c1, c2)?;
586                return Ok(());
587            }
588            _ => {}
589        }
590
591        self.residuals.insert(Expr::BinaryExpr(BinaryExpr {
592            left: Box::new(left.clone()),
593            op,
594            right: Box::new(right.clone()),
595        }));
596
597        Ok(())
598    }
599
600    /// Test that all column equivalence classes of `other` are subsumed by one from `self`.
601    /// This is called the 'equijoin' subsumption test because column equivalences often
602    /// result from join predicates.
603    /// Returns any compensating column equality predicates that should be applied to
604    /// make this plan match the output of the other one.
605    fn equijoin_subsumption_test(&self, other: &Self) -> Option<Vec<Expr>> {
606        let mut new_equivalences = vec![];
607        // check that all equivalence classes of `other` are contained in one from `self`
608        for other_class in &other.eq_classes {
609            let (representative, eq_class) = match other_class
610                .columns
611                .iter()
612                .find_map(|c| self.class_for_column(c).map(|class| (c, class)))
613            {
614                // We don't contain any columns from this eq class.
615                // Technically this is alright if the equivalence class is trivial,
616                // because we're allowed to be a subset of the other plan.
617                // If the equivalence class is nontrivial then we can't compute
618                // compensating filters because we lack the columns that would be
619                // used in the filter.
620                None if other_class.columns.len() == 1 => continue,
621                // We do contain columns from this eq class.
622                Some(tuple) => tuple,
623                // We don't contain columns from this eq class and the
624                // class is nontrivial.
625                _ => return None,
626            };
627
628            if !other_class.columns.is_subset(&eq_class.columns) {
629                return None;
630            }
631
632            for column in eq_class.columns.difference(&other_class.columns) {
633                new_equivalences
634                    .push(Expr::Column(representative.clone()).eq(Expr::Column(column.clone())))
635            }
636        }
637
638        log::trace!("passed equijoin subsumption test");
639
640        Some(new_equivalences)
641    }
642
643    /// Test that all range filters of `self` are contained in one from `other`.
644    /// This includes equality comparisons, which map to ranges of the form [v, v]
645    /// for some value v.
646    /// Returns any compensating range filters that should be applied to this plan
647    /// to make its output match the other one.
648    fn range_subsumption_test(&self, other: &Self) -> Result<Option<Vec<Expr>>> {
649        let mut extra_range_filters = vec![];
650        for (eq_class, range) in self
651            .eq_classes
652            .iter()
653            .zip(self.ranges_by_equivalence_class.iter())
654        {
655            let range = match range {
656                None => {
657                    // empty; it's always contained in another range
658                    // also this range is never satisfiable, so it's always False
659                    extra_range_filters.push(lit(false));
660                    continue;
661                }
662                Some(range) => range,
663            };
664
665            let (other_column, other_range) = match eq_class.columns.iter().find_map(|c| {
666                other.eq_class_idx_by_column.get(c).and_then(|&idx| {
667                    other.ranges_by_equivalence_class[idx]
668                        .as_ref()
669                        .map(|range| (other.eq_classes[idx].columns.first().unwrap(), range))
670                })
671            }) {
672                None => return Ok(None),
673                Some(range) => range,
674            };
675
676            if other_range.contains(range)? != Interval::CERTAINLY_TRUE {
677                return Ok(None);
678            }
679
680            if range.contains(other_range)? != Interval::CERTAINLY_TRUE {
681                if !(range.lower().is_null() || range.upper().is_null())
682                    && (range.lower().eq(range.upper()))
683                {
684                    // Certain datafusion code paths only work if eq expressions are preserved
685                    // that is, col >= val AND col <= val is not treated the same as col = val
686                    // We special-case this to make sure everything works as expected.
687                    // todo: could this be a logical optimizer?
688                    extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
689                        left: Box::new(Expr::Column(other_column.clone())),
690                        op: Operator::Eq,
691                        right: Box::new(Expr::Literal(range.lower().clone())),
692                    }))
693                } else {
694                    if !range.lower().is_null() {
695                        extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
696                            left: Box::new(Expr::Column(other_column.clone())),
697                            op: Operator::GtEq,
698                            right: Box::new(Expr::Literal(range.lower().clone())),
699                        }))
700                    }
701
702                    if !range.upper().is_null() {
703                        extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
704                            left: Box::new(Expr::Column(other_column.clone())),
705                            op: Operator::LtEq,
706                            right: Box::new(Expr::Literal(range.upper().clone())),
707                        }))
708                    }
709                }
710            }
711        }
712
713        log::trace!("passed range subsumption test");
714
715        Ok(Some(extra_range_filters))
716    }
717
718    /// Test that any "residual" filters (not column equivalence or range filters) from
719    /// `other` have matching entries in `self`.
720    /// For example, a residual filter might look like `x * y > 100`, as this expression
721    /// is neither a column equivalence nor a range filter (importantly, not a range filter
722    /// directly on a column).)
723    /// This ensures that `self` is a subset of `other`.
724    /// Return any residual filters in this plan that are not in the other one.
725    fn residual_subsumption_test(&self, other: &Self) -> Option<Vec<Expr>> {
726        let [self_residuals, other_residuals] = [self.residuals.clone(), other.residuals.clone()]
727            .map(|set| {
728                set.into_iter()
729                    .map(|r| self.normalize_expr(r.clone()))
730                    .collect::<HashSet<Expr>>()
731            });
732
733        if !self_residuals.is_superset(&other_residuals) {
734            return None;
735        }
736
737        log::trace!("passed residual subsumption test");
738
739        Some(
740            self_residuals
741                .difference(&other.residuals)
742                .cloned()
743                .collect_vec(),
744        )
745    }
746
747    /// Rewrite all expressions in terms of their normal representatives
748    /// with respect to this predicate's equivalence classes.
749    fn normalize_expr(&self, e: Expr) -> Expr {
750        e.transform(&|e| {
751            let c = match e {
752                Expr::Column(c) => c,
753                Expr::Alias(alias) => return Ok(Transformed::yes(alias.expr.as_ref().clone())),
754                _ => return Ok(Transformed::no(e)),
755            };
756
757            if let Some(eq_class) = self.class_for_column(&c) {
758                Ok(Transformed::yes(Expr::Column(
759                    eq_class.columns.first().unwrap().clone(),
760                )))
761            } else {
762                Ok(Transformed::no(Expr::Column(c)))
763            }
764        })
765        .map(|t| t.data)
766        // No chance of error since we never return Err -- this unwrap is safe
767        .unwrap()
768    }
769}
770
771/// A collection of columns that are all considered to be equivalent.
772/// In some cases we normalize expressions so that they use the "normal" representative
773/// in place of any other columns in the class.
774/// This normal representative is chosen arbitrarily.
775#[derive(Debug, Clone, Default)]
776struct ColumnEquivalenceClass {
777    // first element is the normal representative of the equivalence class
778    columns: BTreeSet<Column>,
779}
780
781impl ColumnEquivalenceClass {
782    fn new(columns: impl IntoIterator<Item = Column>) -> Self {
783        Self {
784            columns: BTreeSet::from_iter(columns),
785        }
786    }
787
788    fn new_singleton(column: Column) -> Self {
789        Self {
790            columns: BTreeSet::from([column]),
791        }
792    }
793}
794
795/// For each field in the plan's schema, get an expression that represents the field's definition.
796/// Furthermore, normalize all expressions so that the only column expressions refer to directly to tables,
797/// not alias subqueries or child plans.
798///
799/// This essentially is equivalent to rewriting the query as a projection against a cross join.
800fn get_output_exprs(plan: &LogicalPlan) -> Result<Vec<Expr>> {
801    use datafusion_expr::logical_plan::*;
802
803    let output_exprs = match plan {
804        // ignore filter, sort, and limit
805        // they don't change the schema or the definitions
806        LogicalPlan::Filter(_)
807        | LogicalPlan::Sort(_)
808        | LogicalPlan::Limit(_)
809        | LogicalPlan::Distinct(_) => return get_output_exprs(plan.inputs()[0]),
810        LogicalPlan::Projection(Projection { expr, .. }) => Ok(expr.clone()),
811        LogicalPlan::Aggregate(Aggregate {
812            group_expr,
813            aggr_expr,
814            ..
815        }) => Ok(Vec::from_iter(
816            group_expr.iter().chain(aggr_expr.iter()).cloned(),
817        )),
818        LogicalPlan::Window(Window {
819            input, window_expr, ..
820        }) => Ok(Vec::from_iter(
821            input
822                .schema()
823                .fields()
824                .iter()
825                .map(|field| Expr::Column(Column::new_unqualified(field.name())))
826                .chain(window_expr.iter().cloned()),
827        )),
828        // if it's a table scan, just exit early with explicit return
829        LogicalPlan::TableScan(table_scan) => {
830            return Ok(get_table_scan_columns(table_scan)?
831                .into_iter()
832                .map(Expr::Column)
833                .collect())
834        }
835        LogicalPlan::Unnest(unnest) => Ok(unnest
836            .schema
837            .columns()
838            .into_iter()
839            .map(Expr::Column)
840            .collect()),
841        LogicalPlan::Join(join) => Ok(join
842            .left
843            .schema()
844            .columns()
845            .into_iter()
846            .chain(join.right.schema().columns())
847            .map(Expr::Column)
848            .collect_vec()),
849        LogicalPlan::SubqueryAlias(sa) => return get_output_exprs(&sa.input),
850        _ => Err(DataFusionError::NotImplemented(format!(
851            "Logical plan not supported: {}",
852            plan.display()
853        ))),
854    }?;
855
856    flatten_exprs(output_exprs, plan)
857}
858
859/// Recursively normalize expressions so that any columns refer directly to tables and not subqueries.
860fn flatten_exprs(exprs: Vec<Expr>, parent: &LogicalPlan) -> Result<Vec<Expr>> {
861    if matches!(parent, LogicalPlan::TableScan(_)) {
862        return Ok(exprs);
863    }
864
865    let schemas = parent
866        .inputs()
867        .iter()
868        .map(|input| input.schema().as_ref())
869        .collect_vec();
870    let using_columns = parent.using_columns()?;
871
872    let output_exprs_by_child = parent
873        .inputs()
874        .into_iter()
875        .map(get_output_exprs)
876        .collect::<Result<Vec<_>>>()?;
877
878    exprs
879        .into_iter()
880        .map(|expr| {
881            expr.transform_up(&|e| match e {
882                // if the relation is None, it's a column referencing one of the child plans
883                // if the relation is Some, it's a column of a table (most likely) and can be ignored since it's a leaf node
884                // (technically it can also refer to an aliased subquery)
885                Expr::Column(col) => {
886                    // Figure out which child the column belongs to
887                    let col = {
888                        let col = if let LogicalPlan::SubqueryAlias(sa) = parent {
889                            // If the parent is an aliased subquery, with the alias 'x',
890                            // any expressions of the form `x.column1`
891                            // refer to `column` in the input
892                            if col.relation.as_ref() == Some(&sa.alias) {
893                                Column::new_unqualified(col.name)
894                            } else {
895                                // All other columns are assumed to be leaf nodes (direct references to tables)
896                                return Ok(Transformed::no(Expr::Column(col)));
897                            }
898                        } else {
899                            col
900                        };
901
902                        col.normalize_with_schemas_and_ambiguity_check(&[&schemas], &using_columns)?
903                    };
904
905                    // first schema that matches column
906                    // the check from earlier ensures that this will always be Some
907                    // and that there should be only one schema that matches
908                    // (except if it is a USING column, in which case we can pick any match equivalently)
909                    let (child_idx, expr_idx) = schemas
910                        .iter()
911                        .enumerate()
912                        .find_map(|(schema_idx, schema)| {
913                            Some(schema_idx).zip(schema.maybe_index_of_column(&col))
914                        })
915                        .unwrap();
916
917                    Ok(Transformed::yes(
918                        output_exprs_by_child[child_idx][expr_idx].clone(),
919                    ))
920                }
921                _ => Ok(Transformed::no(e)),
922            })
923            .data()
924        })
925        .collect()
926}
927
928/// Return the columns output by this [`TableScan`].
929fn get_table_scan_columns(scan: &TableScan) -> Result<Vec<Column>> {
930    let fields = {
931        let mut schema = scan.source.schema().as_ref().clone();
932        if let Some(ref p) = scan.projection {
933            schema = schema.project(p)?;
934        }
935        schema.fields
936    };
937
938    Ok(fields
939        .into_iter()
940        .map(|field| Column::new(Some(scan.table_name.to_owned()), field.name()))
941        .collect())
942}
943
944#[cfg(test)]
945mod test {
946    use arrow::compute::concat_batches;
947    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
948    use datafusion_common::{DataFusionError, Result};
949    use datafusion_sql::TableReference;
950
951    use super::SpjNormalForm;
952
953    async fn setup() -> Result<SessionContext> {
954        let ctx = SessionContext::new();
955
956        ctx.sql(
957            "CREATE TABLE t1 AS VALUES 
958            ('2021', 3, 'A'),
959            ('2022', 4, 'B'),
960            ('2023', 5, 'C')",
961        )
962        .await
963        .map_err(|e| e.context("parse `t1` table ddl"))?
964        .collect()
965        .await?;
966
967        ctx.sql(
968            "CREATE TABLE example (
969                l_orderkey INT,
970                l_partkey INT,
971                l_shipdate DATE,
972                l_quantity DOUBLE,
973                l_extendedprice DOUBLE,
974                o_custkey INT,
975                o_orderkey INT,
976                o_orderdate DATE,
977                p_name VARCHAR,
978                p_partkey INT
979            )
980        ",
981        )
982        .await
983        .map_err(|e| e.context("parse `example` table ddl"))?
984        .collect()
985        .await?;
986
987        Ok(ctx)
988    }
989
990    struct TestCase {
991        name: &'static str,
992        base: &'static str,
993        query: &'static str,
994    }
995
996    async fn run_test(case: &TestCase) -> Result<()> {
997        let context = setup()
998            .await
999            .map_err(|e| e.context("setup test environment"))?;
1000
1001        let base_plan = context.sql(case.base).await?.into_optimized_plan()?; // Optimize plan to eliminate unnormalized wildcard exprs
1002        let base_normal_form = SpjNormalForm::new(&base_plan)?;
1003
1004        context
1005            .sql(&format!("CREATE TABLE mv AS {}", case.base))
1006            .await?
1007            .collect()
1008            .await?;
1009
1010        let query_plan = context.sql(case.query).await?.into_optimized_plan()?;
1011        let query_normal_form = SpjNormalForm::new(&query_plan)?;
1012
1013        let table_ref = TableReference::bare("mv");
1014        let rewritten = query_normal_form
1015            .rewrite_from(
1016                &base_normal_form,
1017                table_ref.clone(),
1018                provider_as_source(context.table_provider(table_ref).await?),
1019            )?
1020            .ok_or(DataFusionError::Internal(
1021                "expected rewrite to succeed".to_string(),
1022            ))?;
1023
1024        assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());
1025
1026        for plan in [&base_plan, &query_plan, &rewritten] {
1027            context
1028                .execute_logical_plan(plan.clone())
1029                .await?
1030                .explain(false, false)?
1031                .show()
1032                .await?;
1033        }
1034
1035        let expected = concat_batches(
1036            &query_plan.schema().as_ref().clone().into(),
1037            &context
1038                .execute_logical_plan(query_plan)
1039                .await?
1040                .collect()
1041                .await?,
1042        )?;
1043
1044        let result = concat_batches(
1045            &rewritten.schema().as_ref().clone().into(),
1046            &context
1047                .execute_logical_plan(rewritten)
1048                .await?
1049                .collect()
1050                .await?,
1051        )?;
1052
1053        assert_eq!(result, expected);
1054
1055        Ok(())
1056    }
1057
1058    #[tokio::test]
1059    async fn test_rewrite() -> Result<()> {
1060        let _ = env_logger::builder().is_test(true).try_init();
1061        let cases = vec![
1062            TestCase {
1063                name: "simple selection",
1064                base: "SELECT * FROM t1",
1065                query: "SELECT column1, column2 FROM t1",
1066            },
1067            TestCase {
1068                name: "selection with equality predicate",
1069                base: "SELECT * FROM t1",
1070                query: "SELECT column1, column2 FROM t1 WHERE column1 = column3",
1071            },
1072            TestCase {
1073                name: "selection with range filter",
1074                base: "SELECT * FROM t1 WHERE column2 > 3",
1075                query: "SELECT column1, column2 FROM t1 WHERE column2 > 4",
1076            },
1077            TestCase {
1078                name: "nontrivial projection",
1079                base: "SELECT concat(column1, column2), column2 FROM t1",
1080                query: "SELECT concat(column1, column2) FROM t1",
1081            },
1082            TestCase {
1083                name: "range filter + equality predicate",
1084                base:
1085                    "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'",
1086                query:
1087                // Since column1 = column3 in the original view,
1088                // we are allowed to substitute column1 for column3 and vice versa.
1089                    "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'",
1090            },
1091            TestCase {
1092                name: "duplicate expressions (X-209)",
1093                base: "SELECT * FROM t1",
1094                query:
1095                    "SELECT column1, NULL AS column2, NULL AS column3, column3 AS column4 FROM t1",
1096            },
1097            TestCase {
1098                name: "example from paper",
1099                base: "\
1100                SELECT 
1101                    l_orderkey, 
1102                    o_custkey, 
1103                    l_partkey,
1104                    l_shipdate, o_orderdate,
1105                    l_quantity*l_extendedprice AS gross_revenue
1106                FROM example
1107                WHERE
1108                    l_orderkey = o_orderkey AND 
1109                    l_partkey = p_partkey AND 
1110                    p_partkey >= 150 AND 
1111                    o_custkey >= 50 AND 
1112                    o_custkey <= 500 AND 
1113                    p_name LIKE '%abc%'
1114                ",
1115                query: "SELECT 
1116                    l_orderkey, 
1117                    o_custkey, 
1118                    l_partkey,
1119                    l_quantity*l_extendedprice
1120                FROM example
1121                WHERE 
1122                    l_orderkey = o_orderkey AND
1123                    l_partkey = p_partkey AND
1124                    l_partkey >= 150 AND 
1125                    l_partkey <= 160 AND
1126                    o_custkey = 123 AND
1127                    o_orderdate = l_shipdate AND
1128                    p_name like '%abc%' AND
1129                    l_quantity*l_extendedprice > 100
1130                ",
1131            },
1132        ];
1133
1134        for case in cases {
1135            println!("executing test: {}", case.name);
1136            run_test(&case).await.map_err(|e| e.context(case.name))?;
1137        }
1138
1139        Ok(())
1140    }
1141}