Skip to main content

datafusion_physical_optimizer/
limited_distinct_aggregation.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! A special-case optimizer rule that pushes a limit into a grouped aggregation
//! that has no aggregate expressions and no sorting requirements
20
21use std::sync::Arc;
22
23use datafusion_physical_plan::aggregates::AggregateExec;
24use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
25use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
26
27use datafusion_common::Result;
28use datafusion_common::config::ConfigOptions;
29use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30
31use crate::PhysicalOptimizerRule;
32use itertools::Itertools;
33
/// Physical optimizer rule that forwards a `limit` hint into grouped aggregations for which
/// correctness does not depend on processing every row of every group.
///
/// Queries of the following shape are candidates:
/// - `SELECT distinct l_orderkey FROM lineitem LIMIT 10;`
/// - `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;`
///
/// The rule holds no state of its own.
#[derive(Debug)]
pub struct LimitedDistinctAggregation {}
40
41impl LimitedDistinctAggregation {
42    /// Create a new `LimitedDistinctAggregation`
43    pub fn new() -> Self {
44        Self {}
45    }
46
47    fn transform_agg(
48        aggr: &AggregateExec,
49        limit: usize,
50    ) -> Option<Arc<dyn ExecutionPlan>> {
51        // rules for transforming this Aggregate are held in this method
52        if !aggr.is_unordered_unfiltered_group_by_distinct() {
53            return None;
54        }
55
56        // We found what we want: clone, copy the limit down, and return modified node
57        let new_aggr = AggregateExec::try_new(
58            *aggr.mode(),
59            aggr.group_expr().clone(),
60            aggr.aggr_expr().to_vec(),
61            aggr.filter_expr().to_vec(),
62            aggr.input().to_owned(),
63            aggr.input_schema(),
64        )
65        .expect("Unable to copy Aggregate!")
66        .with_limit(Some(limit));
67        Some(Arc::new(new_aggr))
68    }
69
70    /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec`
71    /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when
72    /// there is a group by, but no sorting, no aggregate expressions, and no filters in the
73    /// aggregation
74    fn transform_limit(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
75        let limit: usize;
76        let mut global_fetch: Option<usize> = None;
77        let mut global_skip: usize = 0;
78        let children: Vec<Arc<dyn ExecutionPlan>>;
79        let mut is_global_limit = false;
80        if let Some(local_limit) = plan.as_any().downcast_ref::<LocalLimitExec>() {
81            limit = local_limit.fetch();
82            children = local_limit.children().into_iter().cloned().collect();
83        } else if let Some(global_limit) = plan.as_any().downcast_ref::<GlobalLimitExec>()
84        {
85            global_fetch = global_limit.fetch();
86            global_fetch?;
87            global_skip = global_limit.skip();
88            // the aggregate must read at least fetch+skip number of rows
89            limit = global_fetch.unwrap() + global_skip;
90            children = global_limit.children().into_iter().cloned().collect();
91            is_global_limit = true
92        } else {
93            return None;
94        }
95        let child = children.iter().exactly_one().ok()?;
96        // ensure there is no output ordering; can this rule be relaxed?
97        if plan.output_ordering().is_some() {
98            return None;
99        }
100        // ensure no ordering is required on the input
101        if plan.required_input_ordering()[0].is_some() {
102            return None;
103        }
104
105        // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by
106        // must match that of a child aggregation in order to rewrite the child aggregation
107        let mut match_aggr: Arc<dyn ExecutionPlan> = plan;
108        let mut found_match_aggr = false;
109
110        let mut rewrite_applicable = true;
111        let closure = |plan: Arc<dyn ExecutionPlan>| {
112            if !rewrite_applicable {
113                return Ok(Transformed::no(plan));
114            }
115            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
116                if found_match_aggr
117                    && let Some(parent_aggr) =
118                        match_aggr.as_any().downcast_ref::<AggregateExec>()
119                    && !parent_aggr.group_expr().eq(aggr.group_expr())
120                {
121                    // a partial and final aggregation with different groupings disqualifies
122                    // rewriting the child aggregation
123                    rewrite_applicable = false;
124                    return Ok(Transformed::no(plan));
125                }
126                // either we run into an Aggregate and transform it, or disable the rewrite
127                // for subsequent children
128                match Self::transform_agg(aggr, limit) {
129                    None => {}
130                    Some(new_aggr) => {
131                        match_aggr = plan;
132                        found_match_aggr = true;
133                        return Ok(Transformed::yes(new_aggr));
134                    }
135                }
136            }
137            rewrite_applicable = false;
138            Ok(Transformed::no(plan))
139        };
140        let child = child.to_owned().transform_down(closure).data().ok()?;
141        if is_global_limit {
142            return Some(Arc::new(GlobalLimitExec::new(
143                child,
144                global_skip,
145                global_fetch,
146            )));
147        }
148        Some(Arc::new(LocalLimitExec::new(child, limit)))
149    }
150}
151
152impl Default for LimitedDistinctAggregation {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158impl PhysicalOptimizerRule for LimitedDistinctAggregation {
159    fn optimize(
160        &self,
161        plan: Arc<dyn ExecutionPlan>,
162        config: &ConfigOptions,
163    ) -> Result<Arc<dyn ExecutionPlan>> {
164        if config.optimizer.enable_distinct_aggregation_soft_limit {
165            plan.transform_down(|plan| {
166                Ok(
167                    if let Some(plan) =
168                        LimitedDistinctAggregation::transform_limit(plan.to_owned())
169                    {
170                        Transformed::yes(plan)
171                    } else {
172                        Transformed::no(plan)
173                    },
174                )
175            })
176            .data()
177        } else {
178            Ok(plan)
179        }
180    }
181
182    fn name(&self) -> &str {
183        "LimitedDistinctAggregation"
184    }
185
186    fn schema_check(&self) -> bool {
187        true
188    }
189}
190
191// See tests in datafusion/core/tests/physical_optimizer