datafusion_materialized_views/rewrite/
normal_form.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18/*!
19
20This module contains code primarily used for view matching. We implement the view matching algorithm from [this paper](https://dsg.uwaterloo.ca/seminars/notes/larson-paper.pdf),
21which provides a method for determining when one Select-Project-Join query can be rewritten in terms of another Select-Project-Join query.
22
23The implementation is contained in [`SpjNormalForm::rewrite_from`]. The method can be summarized as follows:
241. Compute column equivalence classes for the query and the view.
252. Compute range intervals for the query and the view.
263. (Equijoin subsumption test) Check that each column equivalence class of the view is a subset of a column equivalence class of the query.
274. (Range subsumption test) Check that each range of the view contains the corresponding range from the query.
285. (Residual subsumption test) Check that every filter in the view that is not a column equivalence relation or a range filter matches a filter from the query.
296. Compute any compensating filters needed in order to restrict the view's rows to match the query.
307. Check that the output of the query, and the compensating filters, can be rewritten using the view's columns as inputs.
31
32# Example
33
34Consider the following table:
35
36```sql
37CREATE TABLE example (
38    l_orderkey INT,
39    l_partkey INT,
40    l_shipdate DATE,
41    l_quantity DOUBLE,
42    l_extendedprice DOUBLE,
43    o_custkey INT,
44    o_orderkey INT,
45    o_orderdate DATE,
46    p_name VARCHAR,
47    p_partkey INT,
48)
49```
50
51And consider the follow view:
52
53```sql
54CREATE VIEW mv AS SELECT
55    l_orderkey,
56    o_custkey,
57    l_partkey,
58    l_shipdate, o_orderdate,
59    l_quantity*l_extendedprice AS gross_revenue
60FROM example
61WHERE
62    l_orderkey = o_orderkey AND
63    l_partkey = p_partkey AND
64    p_partkey >= 150 AND
65    o_custkey >= 50 AND
66    o_custkey <= 500 AND
67    p_name LIKE '%abc%'
68```
69
70During analysis, we look at the implied equivalence classes and possible range of values for each equivalence class.
71For this view, the following nontrivial equivalence classes are generated:
72 * `{l_orderkey, o_orderkey}`
73 * `{l_partkey, p_partkey}`
74
75Note that all other columns have their own singleton equivalence classes, but are not shown here.
76Likewise, the following nontrivial ranges are generated:
77 * `150 <= {l_partkey, p_partkey} < inf`
78 * `50 <= {o_custkey} <= 500`
79
80The rest of the equivalence classes are considered to have ranges of (-inf, inf).
81The remaining filter `p_name LIKE '%abc%'` is considered 'residual' as it is not a column equivalence nor a range filter.
82
83Now consider the following query, which we will rewrite to use the view:
84
85```sql
86SELECT
87    l_orderkey,
88    o_custkey,
89    l_partkey,
90    l_quantity*l_extendedprice
91FROM example
92WHERE
93    l_orderkey = o_orderkey AND
94    l_partkey = p_partkey AND
95    l_partkey >= 150 AND
96    l_partkey <= 160 AND
97    o_custkey = 123 AND
98    o_orderdate = l_shipdate AND
99    p_name like '%abc%' AND
100    l_quantity*l_extendedprice > 100
101````
102
103This generates the following equivalence classes:
104 * `{l_orderkey, o_orderkey}`
105 * `{l_partkey, p_partkey}`
106 * `{o_orderdate, l_shipdate}`
107
108And the following ranges:
109 * `150 <= {l_partkey, p_partkey} <= 160`
110 * `123 <= {o_custkey} <= 123`
111
112As before, we still have the residual filter `p_name LIKE '%abc'`. However, note that `l_quantity*l_extendedprice > 100` is also
113a residual filter, as it is not a range filter on a column -- it's a range filter on a mathematical expression.
114
115We perform the three subsumption tests:
116 * Equijoin subsumption test:
117   * View equivalence classes: `{l_orderkey, o_orderkey}, {l_partkey, p_partkey}`
118   * Query equivalence classes: `{l_orderkey, o_orderkey}, {l_partkey, p_partkey}, {o_orderdate, l_shipdate}`
119   * Every view equivalence class is a subset of one from the query, so the test is passed.
120   * We generate the compensating filter `o_orderdate = l_shipdate`.
121 * Range subsumption test:
122   * View ranges:
123     * `150 <= {l_partkey, p_partkey} < inf`
124     * `50 <= {o_custkey} <= 500`
125   * Query ranges:
126     * `150 <= {l_partkey, p_partkey} <= 160`
127     * `123 <= {o_custkey} <= 123`
128   * Both of the view's ranges contain corresponding ranges from the query, therefore the test is passed.
129   * Since they're both strict inclusions, we include them both as compensating filters.
130 * Residual subsumption test:
131   * View residuals: `p_name LIKE '%abc'`
132   * Query residuals: `p_name LIKE '%abc'`, `l_quantity*l_extendedprice > 100`
133   * Every view residual has a matching residual from the query, and the test is passed.
134   * The leftover residual in the query, `l_quantity*l_extendedprice > 100`, is included as a compensating filter.
135
136Ultimately we have the following compensating filters:
137 * `o_orderdate = l_shipdate`
138 * `150 <= {l_partkey, p_partkey} <= 160`
139 * `123 <= {o_custkey} <= 123`
140 * `l_quantity*l_extendedprice > 100`
141
142The final check is to ensure that the output of our query can be computed from the view. This includes
143any expressions used in the compensating filters.
144This is a relatively simple check that mostly involves rewriting expressions to use columns from the view,
145and ensuring that no references to the original tables are left.
146
147This example is included as a unit test. After rewriting the query to use the view, the resulting plan looks like this:
148
149```text
150+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
151| plan_type     | plan                                                                                                                                                                                                     |
152+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
153| logical_plan  | Projection: mv.l_orderkey AS l_orderkey, mv.o_custkey AS o_custkey, mv.l_partkey AS l_partkey, mv.gross_revenue AS example.l_quantity * example.l_extendedprice                                          |
154|               |   Filter: mv.o_orderdate = mv.l_shipdate AND mv.l_partkey >= Int32(150) AND mv.l_partkey <= Int32(160) AND mv.o_custkey >= Int32(123) AND mv.o_custkey <= Int32(123) AND mv.gross_revenue > Float64(100) |
155|               |     TableScan: mv projection=[l_orderkey, o_custkey, l_partkey, l_shipdate, o_orderdate, gross_revenue]                                                                                                  |
156| physical_plan | ProjectionExec: expr=[l_orderkey@0 as l_orderkey, o_custkey@1 as o_custkey, l_partkey@2 as l_partkey, gross_revenue@5 as example.l_quantity * example.l_extendedprice]                                   |
157|               |   CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                            |
158|               |     FilterExec: o_orderdate@4 = l_shipdate@3 AND l_partkey@2 >= 150 AND l_partkey@2 <= 160 AND o_custkey@1 >= 123 AND o_custkey@1 <= 123 AND gross_revenue@5 > 100                                       |
159|               |       MemoryExec: partitions=16, partition_sizes=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]                                                                                                        |
160|               |                                                                                                                                                                                                          |
161+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
162```
163
164As one can see, all compensating filters are included, and the query only uses the view.
165
166*/
167
168use std::{
169    collections::{BTreeSet, HashMap, HashSet},
170    sync::Arc,
171};
172
173use datafusion_common::{
174    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
175    Column, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, TableReference,
176};
177use datafusion_expr::{
178    interval_arithmetic::{satisfy_greater, Interval},
179    lit,
180    utils::split_conjunction,
181    BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, TableScan, TableSource,
182};
183use itertools::Itertools;
184
185/// A normalized representation of a plan containing only Select/Project/Join in the relational algebra sense.
186/// In DataFusion terminology this also includes Filter nodes.
187/// Joins are not currently supported, but are planned.
188#[derive(Debug, Clone)]
189pub struct SpjNormalForm {
190    output_schema: Arc<DFSchema>,
191    output_exprs: Vec<Expr>,
192    referenced_tables: Vec<TableReference>,
193    predicate: Predicate,
194}
195
196/// Rewrite an expression to re-use output columns from this plan, where possible.
197impl TreeNodeRewriter for &SpjNormalForm {
198    type Node = Expr;
199
200    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
201        Ok(match self.output_exprs.iter().position(|x| x == &node) {
202            Some(idx) => Transformed::yes(Expr::Column(Column::new_unqualified(
203                self.output_schema.field(idx).name().clone(),
204            ))),
205            None => Transformed::no(node),
206        })
207    }
208}
209
210impl SpjNormalForm {
211    /// Schema of data output by this plan.
212    pub fn output_schema(&self) -> &Arc<DFSchema> {
213        &self.output_schema
214    }
215
216    /// Expressions output by this plan.
217    /// These expressions can be used to rewrite this plan as a cross join followed by a projection;
218    /// however, this does not include any filters in the original plan, so the result will be a superset.
219    pub fn output_exprs(&self) -> &[Expr] {
220        &self.output_exprs
221    }
222
223    /// All tables referenced in this plan.
224    pub fn referenced_tables(&self) -> &[TableReference] {
225        &self.referenced_tables
226    }
227
228    /// Analyze an existing `LogicalPlan` and rewrite it in select-project-join normal form.
229    pub fn new(original_plan: &LogicalPlan) -> Result<Self> {
230        let predicate = Predicate::new(original_plan)?;
231        let output_exprs = get_output_exprs(original_plan)?
232            .into_iter()
233            .map(|expr| predicate.normalize_expr(expr))
234            .collect();
235
236        let mut referenced_tables = vec![];
237        original_plan
238            .apply(|plan| {
239                if let LogicalPlan::TableScan(scan) = plan {
240                    referenced_tables.push(scan.table_name.clone());
241                }
242
243                Ok(TreeNodeRecursion::Continue)
244            })
245            // No chance of error since we never return Err -- this unwrap is safe
246            .unwrap();
247
248        Ok(Self {
249            output_schema: Arc::clone(original_plan.schema()),
250            output_exprs,
251            referenced_tables,
252            predicate,
253        })
254    }
255
256    /// Rewrite this plan as as selection/projection on top of another plan,
257    /// which we use `qualifier` to refer to.
258    /// This is useful for rewriting queries to use materialized views.
259    pub fn rewrite_from(
260        &self,
261        mut other: &Self,
262        qualifier: TableReference,
263        source: Arc<dyn TableSource>,
264    ) -> Result<Option<LogicalPlan>> {
265        log::trace!("rewriting from {qualifier}");
266        let mut new_output_exprs = Vec::with_capacity(self.output_exprs.len());
267        // check that our output exprs are sub-expressions of the other one's output exprs
268        for (i, output_expr) in self.output_exprs.iter().enumerate() {
269            let new_output_expr = other
270                .predicate
271                .normalize_expr(output_expr.clone())
272                .rewrite(&mut other)?
273                .data;
274
275            // Check that all references to the original tables have been replaced.
276            // All remaining column expressions should be unqualified, which indicates
277            // that they refer to the output of the sub-plan (in this case the view)
278            if new_output_expr
279                .column_refs()
280                .iter()
281                .any(|c| c.relation.is_some())
282            {
283                return Ok(None);
284            }
285
286            let column = &self.output_schema.columns()[i];
287            new_output_exprs.push(
288                new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()),
289            );
290        }
291
292        log::trace!("passed output rewrite");
293
294        // Check the subsumption tests, and compute any auxiliary needed filter expressions.
295        // If we pass all three subsumption tests, this plan's output is a subset of the other
296        // plan's output.
297        let ((eq_filters, range_filters), residual_filters) = match self
298            .predicate
299            .equijoin_subsumption_test(&other.predicate)
300            .zip(self.predicate.range_subsumption_test(&other.predicate)?)
301            .zip(self.predicate.residual_subsumption_test(&other.predicate))
302        {
303            None => return Ok(None),
304            Some(filters) => filters,
305        };
306
307        log::trace!("passed subsumption tests");
308
309        let all_filters = eq_filters
310            .into_iter()
311            .chain(range_filters)
312            .chain(residual_filters)
313            .map(|expr| expr.rewrite(&mut other).unwrap().data)
314            .reduce(|a, b| a.and(b));
315
316        if all_filters
317            .as_ref()
318            .map(|expr| expr.column_refs())
319            .is_some_and(|columns| columns.iter().any(|c| c.relation.is_some()))
320        {
321            return Ok(None);
322        }
323
324        let mut builder = LogicalPlanBuilder::scan(qualifier, source, None)?;
325
326        if let Some(filter) = all_filters {
327            builder = builder.filter(filter)?;
328        }
329
330        builder.project(new_output_exprs)?.build().map(Some)
331    }
332}
333
334/// Stores information on filters from a Select-Project-Join plan.
335#[derive(Debug, Clone)]
336struct Predicate {
337    /// Full table schema, including all possible columns.
338    schema: DFSchema,
339    /// List of column equivalence classes.
340    eq_classes: Vec<ColumnEquivalenceClass>,
341    /// Reverse lookup by eq class elements
342    eq_class_idx_by_column: HashMap<Column, usize>,
343    /// Stores (possibly empty) intervals describing each equivalence class.
344    ranges_by_equivalence_class: Vec<Option<Interval>>,
345    /// Filter expressions that aren't column equality predicates or range filters.
346    residuals: HashSet<Expr>,
347}
348
349impl Predicate {
350    fn new(plan: &LogicalPlan) -> Result<Self> {
351        let mut schema = DFSchema::empty();
352        plan.apply(|plan| {
353            if let LogicalPlan::TableScan(scan) = plan {
354                let new_schema = DFSchema::try_from_qualified_schema(
355                    scan.table_name.clone(),
356                    scan.source.schema().as_ref(),
357                )?;
358                schema = if schema.fields().is_empty() {
359                    new_schema
360                } else {
361                    schema.join(&new_schema)?
362                }
363            }
364
365            Ok(TreeNodeRecursion::Continue)
366        })?;
367
368        let mut new = Self {
369            schema,
370            eq_classes: vec![],
371            eq_class_idx_by_column: HashMap::default(),
372            ranges_by_equivalence_class: vec![],
373            residuals: HashSet::new(),
374        };
375
376        // Collect all referenced columns
377        plan.apply(|plan| {
378            if let LogicalPlan::TableScan(scan) = plan {
379                for (i, (table_ref, field)) in DFSchema::try_from_qualified_schema(
380                    scan.table_name.clone(),
381                    scan.source.schema().as_ref(),
382                )?
383                .iter()
384                .enumerate()
385                {
386                    let column = Column::new(table_ref.cloned(), field.name());
387                    let data_type = field.data_type();
388                    new.eq_classes
389                        .push(ColumnEquivalenceClass::new_singleton(column.clone()));
390                    new.eq_class_idx_by_column.insert(column, i);
391                    new.ranges_by_equivalence_class
392                        .push(Some(Interval::make_unbounded(data_type)?));
393                }
394            }
395
396            Ok(TreeNodeRecursion::Continue)
397        })?;
398
399        // Collect any filters
400        plan.apply(|plan| {
401            let filters = match plan {
402                LogicalPlan::TableScan(scan) => scan.filters.as_slice(),
403                LogicalPlan::Filter(filter) => core::slice::from_ref(&filter.predicate),
404                LogicalPlan::Join(_join) => {
405                    return Err(DataFusionError::Internal(
406                        "joins are not supported yet".to_string(),
407                    ))
408                }
409                LogicalPlan::Projection(_) => &[],
410                _ => {
411                    return Err(DataFusionError::Plan(format!(
412                        "unsupported logical plan: {}",
413                        plan.display()
414                    )))
415                }
416            };
417
418            for expr in filters.iter().flat_map(split_conjunction) {
419                new.insert_conjuct(expr)?;
420            }
421
422            Ok(TreeNodeRecursion::Continue)
423        })?;
424
425        Ok(new)
426    }
427
428    fn class_for_column(&self, col: &Column) -> Option<&ColumnEquivalenceClass> {
429        self.eq_class_idx_by_column
430            .get(col)
431            .and_then(|&idx| self.eq_classes.get(idx))
432    }
433
434    /// Add a new column equivalence
435    fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> {
436        match (
437            self.eq_class_idx_by_column.get(c1),
438            self.eq_class_idx_by_column.get(c2),
439        ) {
440            (None, None) => {
441                // Make a new eq class [c1, c2]
442                self.eq_classes
443                    .push(ColumnEquivalenceClass::new([c1.clone(), c2.clone()]));
444                self.ranges_by_equivalence_class
445                    .push(Some(Interval::make_unbounded(
446                        self.schema.field_from_column(c1).unwrap().data_type(),
447                    )?));
448            }
449
450            // These two cases are just adding a column to an existing class
451            (None, Some(&idx)) => {
452                self.eq_classes[idx].columns.insert(c1.clone());
453            }
454            (Some(&idx), None) => {
455                self.eq_classes[idx].columns.insert(c2.clone());
456            }
457            (Some(&i), Some(&j)) => {
458                if i == j {
459                    // The two columns are already in the same equivalence class.
460                    return Ok(());
461                }
462                // We need to merge two existing column eq classes.
463
464                // Delete the eq class with a larger index,
465                // so that the other one has its position preserved.
466                // Not necessary, but it's just a little simpler this way
467                let (i, j) = if i < j { (i, j) } else { (j, i) };
468
469                // Merge the deleted eq class with its new partner
470                let new_columns = self.eq_classes.remove(j).columns;
471                self.eq_classes[i].columns.extend(new_columns.clone());
472                for column in new_columns {
473                    self.eq_class_idx_by_column.insert(column, i);
474                }
475                // update all moved entries
476                for idx in self.eq_class_idx_by_column.values_mut() {
477                    if *idx > j {
478                        *idx -= 1;
479                    }
480                }
481
482                // Merge ranges
483                // Now that we know the two equivalence classes are equal,
484                // the new range is the intersection of the existing two ranges.
485                self.ranges_by_equivalence_class[i] = self.ranges_by_equivalence_class[i]
486                    .clone()
487                    .zip(self.ranges_by_equivalence_class.remove(j))
488                    .and_then(|(range, other_range)| range.intersect(other_range).transpose())
489                    .transpose()?;
490            }
491        }
492
493        Ok(())
494    }
495
496    /// Update range for a column's equivalence class
497    fn add_range(&mut self, c: &Column, op: &Operator, value: &ScalarValue) -> Result<()> {
498        // first coerce the value if needed
499        let value = value.cast_to(self.schema.data_type(c)?)?;
500        let range = self
501            .eq_class_idx_by_column
502            .get(c)
503            .ok_or_else(|| {
504                DataFusionError::Plan(format!("column {c} not found in equivalence classes"))
505            })
506            .and_then(|&idx| {
507                self.ranges_by_equivalence_class
508                    .get_mut(idx)
509                    .ok_or_else(|| {
510                        DataFusionError::Plan(format!(
511                            "range not found class not found for column {c} with equivalence class {:?}", self.eq_classes.get(idx)
512                        ))
513                    })
514            })?;
515
516        let new_range = match op {
517            Operator::Eq => Interval::try_new(value.clone(), value.clone()),
518            Operator::LtEq => {
519                Interval::try_new(ScalarValue::try_from(value.data_type())?, value.clone())
520            }
521            Operator::GtEq => {
522                Interval::try_new(value.clone(), ScalarValue::try_from(value.data_type())?)
523            }
524
525            // Note: This is a roundabout way (read: hack) to construct an open Interval.
526            // DataFusion's Interval type represents closed intervals,
527            // so handling of open intervals is done by adding/subtracting the smallest increment.
528            // However, there is not really a public API to do this,
529            // other than the satisfy_greater method.
530            Operator::Lt => Ok(
531                match satisfy_greater(
532                    &Interval::try_new(value.clone(), value.clone())?,
533                    &Interval::make_unbounded(&value.data_type())?,
534                    true,
535                )? {
536                    Some((_, range)) => range,
537                    None => {
538                        *range = None;
539                        return Ok(());
540                    }
541                },
542            ),
543            // Same thing as above.
544            Operator::Gt => Ok(
545                match satisfy_greater(
546                    &Interval::make_unbounded(&value.data_type())?,
547                    &Interval::try_new(value.clone(), value.clone())?,
548                    true,
549                )? {
550                    Some((range, _)) => range,
551                    None => {
552                        *range = None;
553                        return Ok(());
554                    }
555                },
556            ),
557            _ => Err(DataFusionError::Plan(
558                "unsupported binary expression".to_string(),
559            )),
560        }?;
561
562        *range = match range {
563            None => Some(new_range),
564            Some(range) => range.intersect(new_range)?,
565        };
566
567        Ok(())
568    }
569
570    /// Add a generic filter expression to our collection of filters.
571    /// A conjunct is a term T_i of an expression T_1 AND T_2 AND T_3 AND ...
572    fn insert_conjuct(&mut self, expr: &Expr) -> Result<()> {
573        match expr {
574            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
575                self.insert_binary_expr(left, *op, right)?;
576            }
577            Expr::Not(e) => match e.as_ref() {
578                Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
579                    if let Some(negated) = op.negate() {
580                        self.insert_binary_expr(left, negated, right)?;
581                    } else {
582                        self.residuals.insert(expr.clone());
583                    }
584                }
585                _ => {
586                    self.residuals.insert(expr.clone());
587                }
588            },
589            _ => {
590                self.residuals.insert(expr.clone());
591            }
592        }
593
594        Ok(())
595    }
596
597    /// Add a binary expression to our collection of filters.
598    fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> {
599        match (left, op, right) {
600            (Expr::Column(c), op, Expr::Literal(v, _)) => {
601                if let Err(e) = self.add_range(c, &op, v) {
602                    // Add a range can fail in some cases, so just fallthrough
603                    log::debug!("failed to add range filter: {e}");
604                } else {
605                    return Ok(());
606                }
607            }
608            (Expr::Literal(_, _), op, Expr::Column(_)) => {
609                if let Some(swapped) = op.swap() {
610                    return self.insert_binary_expr(right, swapped, left);
611                }
612            }
613            // update eq classes & merge ranges by eq class
614            (Expr::Column(c1), Operator::Eq, Expr::Column(c2)) => {
615                self.add_equivalence(c1, c2)?;
616                return Ok(());
617            }
618            _ => {}
619        }
620
621        self.residuals.insert(Expr::BinaryExpr(BinaryExpr {
622            left: Box::new(left.clone()),
623            op,
624            right: Box::new(right.clone()),
625        }));
626
627        Ok(())
628    }
629
630    /// Test that all column equivalence classes of `other` are subsumed by one from `self`.
631    /// This is called the 'equijoin' subsumption test because column equivalences often
632    /// result from join predicates.
633    /// Returns any compensating column equality predicates that should be applied to
634    /// make this plan match the output of the other one.
635    fn equijoin_subsumption_test(&self, other: &Self) -> Option<Vec<Expr>> {
636        let mut new_equivalences = vec![];
637        // check that all equivalence classes of `other` are contained in one from `self`
638        for other_class in &other.eq_classes {
639            let (representative, eq_class) = match other_class
640                .columns
641                .iter()
642                .find_map(|c| self.class_for_column(c).map(|class| (c, class)))
643            {
644                // We don't contain any columns from this eq class.
645                // Technically this is alright if the equivalence class is trivial,
646                // because we're allowed to be a subset of the other plan.
647                // If the equivalence class is nontrivial then we can't compute
648                // compensating filters because we lack the columns that would be
649                // used in the filter.
650                None if other_class.columns.len() == 1 => continue,
651                // We do contain columns from this eq class.
652                Some(tuple) => tuple,
653                // We don't contain columns from this eq class and the
654                // class is nontrivial.
655                _ => return None,
656            };
657
658            if !other_class.columns.is_subset(&eq_class.columns) {
659                return None;
660            }
661
662            for column in eq_class.columns.difference(&other_class.columns) {
663                new_equivalences
664                    .push(Expr::Column(representative.clone()).eq(Expr::Column(column.clone())))
665            }
666        }
667
668        log::trace!("passed equijoin subsumption test");
669
670        Some(new_equivalences)
671    }
672
673    /// Test that all range filters of `self` are contained in one from `other`.
674    /// This includes equality comparisons, which map to ranges of the form [v, v]
675    /// for some value v.
676    /// Returns any compensating range filters that should be applied to this plan
677    /// to make its output match the other one.
678    fn range_subsumption_test(&self, other: &Self) -> Result<Option<Vec<Expr>>> {
679        let mut extra_range_filters = vec![];
680        for (eq_class, range) in self
681            .eq_classes
682            .iter()
683            .zip(self.ranges_by_equivalence_class.iter())
684        {
685            let range = match range {
686                None => {
687                    // empty; it's always contained in another range
688                    // also this range is never satisfiable, so it's always False
689                    extra_range_filters.push(lit(false));
690                    continue;
691                }
692                Some(range) => range,
693            };
694
695            let (other_column, other_range) = match eq_class.columns.iter().find_map(|c| {
696                other.eq_class_idx_by_column.get(c).and_then(|&idx| {
697                    other.ranges_by_equivalence_class[idx]
698                        .as_ref()
699                        .map(|range| (other.eq_classes[idx].columns.first().unwrap(), range))
700                })
701            }) {
702                None => return Ok(None),
703                Some(range) => range,
704            };
705
706            if other_range.contains(range)? != Interval::CERTAINLY_TRUE {
707                return Ok(None);
708            }
709
710            if range.contains(other_range)? != Interval::CERTAINLY_TRUE {
711                if !(range.lower().is_null() || range.upper().is_null())
712                    && (range.lower().eq(range.upper()))
713                {
714                    // Certain datafusion code paths only work if eq expressions are preserved
715                    // that is, col >= val AND col <= val is not treated the same as col = val
716                    // We special-case this to make sure everything works as expected.
717                    // todo: could this be a logical optimizer?
718                    extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
719                        left: Box::new(Expr::Column(other_column.clone())),
720                        op: Operator::Eq,
721                        right: Box::new(Expr::Literal(range.lower().clone(), None)),
722                    }))
723                } else {
724                    if !range.lower().is_null() {
725                        extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
726                            left: Box::new(Expr::Column(other_column.clone())),
727                            op: Operator::GtEq,
728                            right: Box::new(Expr::Literal(range.lower().clone(), None)),
729                        }))
730                    }
731
732                    if !range.upper().is_null() {
733                        extra_range_filters.push(Expr::BinaryExpr(BinaryExpr {
734                            left: Box::new(Expr::Column(other_column.clone())),
735                            op: Operator::LtEq,
736                            right: Box::new(Expr::Literal(range.upper().clone(), None)),
737                        }))
738                    }
739                }
740            }
741        }
742
743        log::trace!("passed range subsumption test");
744
745        Ok(Some(extra_range_filters))
746    }
747
748    /// Test that any "residual" filters (not column equivalence or range filters) from
749    /// `other` have matching entries in `self`.
750    /// For example, a residual filter might look like `x * y > 100`, as this expression
751    /// is neither a column equivalence nor a range filter (importantly, not a range filter
752    /// directly on a column).)
753    /// This ensures that `self` is a subset of `other`.
754    /// Return any residual filters in this plan that are not in the other one.
755    fn residual_subsumption_test(&self, other: &Self) -> Option<Vec<Expr>> {
756        let [self_residuals, other_residuals] = [self.residuals.clone(), other.residuals.clone()]
757            .map(|set| {
758                set.into_iter()
759                    .map(|r| self.normalize_expr(r.clone()))
760                    .collect::<HashSet<Expr>>()
761            });
762
763        if !self_residuals.is_superset(&other_residuals) {
764            return None;
765        }
766
767        log::trace!("passed residual subsumption test");
768
769        Some(
770            self_residuals
771                .difference(&other.residuals)
772                .cloned()
773                .collect_vec(),
774        )
775    }
776
777    /// Rewrite all expressions in terms of their normal representatives
778    /// with respect to this predicate's equivalence classes.
779    fn normalize_expr(&self, e: Expr) -> Expr {
780        e.transform(&|e| {
781            let c = match e {
782                Expr::Column(c) => c,
783                Expr::Alias(alias) => return Ok(Transformed::yes(alias.expr.as_ref().clone())),
784                _ => return Ok(Transformed::no(e)),
785            };
786
787            if let Some(eq_class) = self.class_for_column(&c) {
788                Ok(Transformed::yes(Expr::Column(
789                    eq_class.columns.first().unwrap().clone(),
790                )))
791            } else {
792                Ok(Transformed::no(Expr::Column(c)))
793            }
794        })
795        .map(|t| t.data)
796        // No chance of error since we never return Err -- this unwrap is safe
797        .unwrap()
798    }
799}
800
801/// A collection of columns that are all considered to be equivalent.
802/// In some cases we normalize expressions so that they use the "normal" representative
803/// in place of any other columns in the class.
804/// This normal representative is chosen arbitrarily.
805#[derive(Debug, Clone, Default)]
806struct ColumnEquivalenceClass {
807    // first element is the normal representative of the equivalence class
808    columns: BTreeSet<Column>,
809}
810
811impl ColumnEquivalenceClass {
812    fn new(columns: impl IntoIterator<Item = Column>) -> Self {
813        Self {
814            columns: BTreeSet::from_iter(columns),
815        }
816    }
817
818    fn new_singleton(column: Column) -> Self {
819        Self {
820            columns: BTreeSet::from([column]),
821        }
822    }
823}
824
825/// For each field in the plan's schema, get an expression that represents the field's definition.
826/// Furthermore, normalize all expressions so that the only column expressions refer to directly to tables,
827/// not alias subqueries or child plans.
828///
829/// This essentially is equivalent to rewriting the query as a projection against a cross join.
830fn get_output_exprs(plan: &LogicalPlan) -> Result<Vec<Expr>> {
831    use datafusion_expr::logical_plan::*;
832
833    let output_exprs = match plan {
834        // ignore filter, sort, and limit
835        // they don't change the schema or the definitions
836        LogicalPlan::Filter(_)
837        | LogicalPlan::Sort(_)
838        | LogicalPlan::Limit(_)
839        | LogicalPlan::Distinct(_) => return get_output_exprs(plan.inputs()[0]),
840        LogicalPlan::Projection(Projection { expr, .. }) => Ok(expr.clone()),
841        LogicalPlan::Aggregate(Aggregate {
842            group_expr,
843            aggr_expr,
844            ..
845        }) => Ok(Vec::from_iter(
846            group_expr.iter().chain(aggr_expr.iter()).cloned(),
847        )),
848        LogicalPlan::Window(Window {
849            input, window_expr, ..
850        }) => Ok(Vec::from_iter(
851            input
852                .schema()
853                .fields()
854                .iter()
855                .map(|field| Expr::Column(Column::new_unqualified(field.name())))
856                .chain(window_expr.iter().cloned()),
857        )),
858        // if it's a table scan, just exit early with explicit return
859        LogicalPlan::TableScan(table_scan) => {
860            return Ok(get_table_scan_columns(table_scan)?
861                .into_iter()
862                .map(Expr::Column)
863                .collect())
864        }
865        LogicalPlan::Unnest(unnest) => Ok(unnest
866            .schema
867            .columns()
868            .into_iter()
869            .map(Expr::Column)
870            .collect()),
871        LogicalPlan::Join(join) => Ok(join
872            .left
873            .schema()
874            .columns()
875            .into_iter()
876            .chain(join.right.schema().columns())
877            .map(Expr::Column)
878            .collect_vec()),
879        LogicalPlan::SubqueryAlias(sa) => return get_output_exprs(&sa.input),
880        _ => Err(DataFusionError::NotImplemented(format!(
881            "Logical plan not supported: {}",
882            plan.display()
883        ))),
884    }?;
885
886    flatten_exprs(output_exprs, plan)
887}
888
889/// Recursively normalize expressions so that any columns refer directly to tables and not subqueries.
890fn flatten_exprs(exprs: Vec<Expr>, parent: &LogicalPlan) -> Result<Vec<Expr>> {
891    if matches!(parent, LogicalPlan::TableScan(_)) {
892        return Ok(exprs);
893    }
894
895    let schemas = parent
896        .inputs()
897        .iter()
898        .map(|input| input.schema().as_ref())
899        .collect_vec();
900    let using_columns = parent.using_columns()?;
901
902    let output_exprs_by_child = parent
903        .inputs()
904        .into_iter()
905        .map(get_output_exprs)
906        .collect::<Result<Vec<_>>>()?;
907
908    exprs
909        .into_iter()
910        .map(|expr| {
911            expr.transform_up(&|e| match e {
912                // if the relation is None, it's a column referencing one of the child plans
913                // if the relation is Some, it's a column of a table (most likely) and can be ignored since it's a leaf node
914                // (technically it can also refer to an aliased subquery)
915                Expr::Column(col) => {
916                    // Figure out which child the column belongs to
917                    let col = {
918                        let col = if let LogicalPlan::SubqueryAlias(sa) = parent {
919                            // If the parent is an aliased subquery, with the alias 'x',
920                            // any expressions of the form `x.column1`
921                            // refer to `column` in the input
922                            if col.relation.as_ref() == Some(&sa.alias) {
923                                Column::new_unqualified(col.name)
924                            } else {
925                                // All other columns are assumed to be leaf nodes (direct references to tables)
926                                return Ok(Transformed::no(Expr::Column(col)));
927                            }
928                        } else {
929                            col
930                        };
931
932                        col.normalize_with_schemas_and_ambiguity_check(&[&schemas], &using_columns)?
933                    };
934
935                    // first schema that matches column
936                    // the check from earlier ensures that this will always be Some
937                    // and that there should be only one schema that matches
938                    // (except if it is a USING column, in which case we can pick any match equivalently)
939                    let (child_idx, expr_idx) = schemas
940                        .iter()
941                        .enumerate()
942                        .find_map(|(schema_idx, schema)| {
943                            Some(schema_idx).zip(schema.maybe_index_of_column(&col))
944                        })
945                        .unwrap();
946
947                    Ok(Transformed::yes(
948                        output_exprs_by_child[child_idx][expr_idx].clone(),
949                    ))
950                }
951                _ => Ok(Transformed::no(e)),
952            })
953            .data()
954        })
955        .collect()
956}
957
958/// Return the columns output by this [`TableScan`].
959fn get_table_scan_columns(scan: &TableScan) -> Result<Vec<Column>> {
960    let fields = {
961        let mut schema = scan.source.schema().as_ref().clone();
962        if let Some(ref p) = scan.projection {
963            schema = schema.project(p)?;
964        }
965        schema.fields
966    };
967
968    Ok(fields
969        .into_iter()
970        .map(|field| Column::new(Some(scan.table_name.to_owned()), field.name()))
971        .collect())
972}
973
974#[cfg(test)]
975mod test {
976    use arrow::compute::concat_batches;
977    use datafusion::{
978        datasource::provider_as_source,
979        prelude::{SessionConfig, SessionContext},
980    };
981    use datafusion_common::{DataFusionError, Result};
982    use datafusion_sql::TableReference;
983    use tempfile::tempdir;
984
985    use super::SpjNormalForm;
986
987    async fn setup() -> Result<SessionContext> {
988        let ctx = SessionContext::new_with_config(
989            SessionConfig::new()
990                .set_bool("datafusion.execution.parquet.pushdown_filters", true)
991                .set_bool("datafusion.explain.logical_plan_only", true),
992        );
993
994        let t1_path = tempdir()?;
995
996        // Create external table to exercise parquet filter pushdown.
997        // This will put the filters directly inside the `TableScan` node.
998        // This is important because `TableScan` can have filters on
999        // columns not in its own output.
1000        ctx.sql(&format!(
1001            "
1002                CREATE EXTERNAL TABLE t1 (
1003                    column1 VARCHAR,
1004                    column2 BIGINT,
1005                    column3 CHAR
1006                )
1007                STORED AS PARQUET
1008                LOCATION '{}'",
1009            t1_path.path().to_string_lossy()
1010        ))
1011        .await
1012        .map_err(|e| e.context("setup `t1` table"))?
1013        .collect()
1014        .await?;
1015
1016        ctx.sql(
1017            "INSERT INTO t1 VALUES
1018            ('2021', 3, 'A'),
1019            ('2022', 4, 'B'),
1020            ('2023', 5, 'C')",
1021        )
1022        .await
1023        .map_err(|e| e.context("parse `t1` table ddl"))?
1024        .collect()
1025        .await?;
1026
1027        ctx.sql(
1028            "CREATE TABLE example (
1029                l_orderkey INT,
1030                l_partkey INT,
1031                l_shipdate DATE,
1032                l_quantity DOUBLE,
1033                l_extendedprice DOUBLE,
1034                o_custkey INT,
1035                o_orderkey INT,
1036                o_orderdate DATE,
1037                p_name VARCHAR,
1038                p_partkey INT
1039            )",
1040        )
1041        .await
1042        .map_err(|e| e.context("parse `example` table ddl"))?
1043        .collect()
1044        .await?;
1045
1046        Ok(ctx)
1047    }
1048
1049    struct TestCase {
1050        name: &'static str,
1051        base: &'static str,
1052        query: &'static str,
1053    }
1054
1055    async fn run_test(case: &TestCase) -> Result<()> {
1056        let context = setup()
1057            .await
1058            .map_err(|e| e.context("setup test environment"))?;
1059
1060        let base_plan = context.sql(case.base).await?.into_optimized_plan()?; // Optimize plan to eliminate unnormalized wildcard exprs
1061        let base_normal_form = SpjNormalForm::new(&base_plan)?;
1062
1063        context
1064            .sql(&format!("CREATE TABLE mv AS {}", case.base))
1065            .await?
1066            .collect()
1067            .await?;
1068
1069        let query_plan = context.sql(case.query).await?.into_optimized_plan()?;
1070        let query_normal_form = SpjNormalForm::new(&query_plan)?;
1071
1072        for plan in [&base_plan, &query_plan] {
1073            context
1074                .execute_logical_plan(plan.clone())
1075                .await?
1076                .explain(false, false)?
1077                .show()
1078                .await?;
1079        }
1080
1081        let table_ref = TableReference::bare("mv");
1082        let rewritten = query_normal_form
1083            .rewrite_from(
1084                &base_normal_form,
1085                table_ref.clone(),
1086                provider_as_source(context.table_provider(table_ref).await?),
1087            )?
1088            .ok_or(DataFusionError::Internal(
1089                "expected rewrite to succeed".to_string(),
1090            ))?;
1091
1092        context
1093            .execute_logical_plan(rewritten.clone())
1094            .await?
1095            .explain(false, false)?
1096            .show()
1097            .await?;
1098
1099        assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());
1100
1101        let expected = concat_batches(
1102            &query_plan.schema().as_ref().clone().into(),
1103            &context
1104                .execute_logical_plan(query_plan)
1105                .await?
1106                .collect()
1107                .await?,
1108        )?;
1109
1110        let result = concat_batches(
1111            &rewritten.schema().as_ref().clone().into(),
1112            &context
1113                .execute_logical_plan(rewritten)
1114                .await?
1115                .collect()
1116                .await?,
1117        )?;
1118
1119        assert_eq!(result, expected);
1120
1121        Ok(())
1122    }
1123
1124    #[tokio::test]
1125    async fn test_rewrite() -> Result<()> {
1126        let _ = env_logger::builder().is_test(true).try_init();
1127        let cases = vec![
1128            TestCase {
1129                name: "simple selection",
1130                base: "SELECT * FROM t1",
1131                query: "SELECT column1, column2 FROM t1",
1132            },
1133            TestCase {
1134                name: "selection with equality predicate",
1135                base: "SELECT * FROM t1",
1136                query: "SELECT column1, column2 FROM t1 WHERE column1 = column3",
1137            },
1138            TestCase {
1139                name: "selection with range filter",
1140                base: "SELECT * FROM t1 WHERE column2 > 3",
1141                query: "SELECT column1, column2 FROM t1 WHERE column2 > 4",
1142            },
1143            TestCase {
1144                name: "nontrivial projection",
1145                base: "SELECT concat(column1, column2), column2 FROM t1",
1146                query: "SELECT concat(column1, column2) FROM t1",
1147            },
1148            TestCase {
1149                name: "range filter + equality predicate",
1150                base:
1151                    "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'",
1152                query:
1153                // Since column1 = column3 in the original view,
1154                // we are allowed to substitute column1 for column3 and vice versa.
1155                    "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'",
1156            },
1157            TestCase {
1158                name: "duplicate expressions (X-209)",
1159                base: "SELECT * FROM t1",
1160                query:
1161                    "SELECT column1, NULL AS column2, NULL AS column3, column3 AS column4 FROM t1",
1162            },
1163            TestCase {
1164                name: "example from paper",
1165                base: "\
1166                SELECT
1167                    l_orderkey,
1168                    o_custkey,
1169                    l_partkey,
1170                    l_shipdate, o_orderdate,
1171                    l_quantity*l_extendedprice AS gross_revenue
1172                FROM example
1173                WHERE
1174                    l_orderkey = o_orderkey AND
1175                    l_partkey = p_partkey AND
1176                    p_partkey >= 150 AND
1177                    o_custkey >= 50 AND
1178                    o_custkey <= 500 AND
1179                    p_name LIKE '%abc%'
1180                ",
1181                query: "SELECT
1182                    l_orderkey,
1183                    o_custkey,
1184                    l_partkey,
1185                    l_quantity*l_extendedprice
1186                FROM example
1187                WHERE
1188                    l_orderkey = o_orderkey AND
1189                    l_partkey = p_partkey AND
1190                    l_partkey >= 150 AND
1191                    l_partkey <= 160 AND
1192                    o_custkey = 123 AND
1193                    o_orderdate = l_shipdate AND
1194                    p_name like '%abc%' AND
1195                    l_quantity*l_extendedprice > 100
1196                ",
1197            },
1198            TestCase {
1199                name: "naked table scan with pushed down filters",
1200                base: "SELECT column1 FROM t1 WHERE column2 <= 3",
1201                query: "SELECT FROM t1 WHERE column2 <= 3",
1202            },
1203        ];
1204
1205        for case in cases {
1206            println!("executing test: {}", case.name);
1207            run_test(&case).await.map_err(|e| e.context(case.name))?;
1208        }
1209
1210        Ok(())
1211    }
1212}