datafusion_physical_optimizer/
output_requirements.rs1use std::sync::Arc;
26
27use crate::PhysicalOptimizerRule;
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::{Result, Statistics};
32use datafusion_execution::TaskContext;
33use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement};
34use datafusion_physical_plan::projection::{
35 make_with_child, update_expr, ProjectionExec,
36};
37use datafusion_physical_plan::sorts::sort::SortExec;
38use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
39use datafusion_physical_plan::{
40 DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream,
41};
42use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties};
43
44#[derive(Debug)]
54pub struct OutputRequirements {
55 mode: RuleMode,
56}
57
58impl OutputRequirements {
59 pub fn new_add_mode() -> Self {
64 Self {
65 mode: RuleMode::Add,
66 }
67 }
68
69 pub fn new_remove_mode() -> Self {
76 Self {
77 mode: RuleMode::Remove,
78 }
79 }
80}
81
/// Operating mode of the `OutputRequirements` rule.
///
/// The derived ordering follows declaration order, so `Add < Remove`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum RuleMode {
    /// Insert the requirement node into the plan.
    Add,
    /// Strip the requirement node(s) from the plan.
    Remove,
}
87
88#[derive(Debug)]
95pub struct OutputRequirementExec {
96 input: Arc<dyn ExecutionPlan>,
97 order_requirement: Option<LexRequirement>,
98 dist_requirement: Distribution,
99 cache: PlanProperties,
100}
101
102impl OutputRequirementExec {
103 pub fn new(
104 input: Arc<dyn ExecutionPlan>,
105 requirements: Option<LexRequirement>,
106 dist_requirement: Distribution,
107 ) -> Self {
108 let cache = Self::compute_properties(&input);
109 Self {
110 input,
111 order_requirement: requirements,
112 dist_requirement,
113 cache,
114 }
115 }
116
117 pub fn input(&self) -> Arc<dyn ExecutionPlan> {
118 Arc::clone(&self.input)
119 }
120
121 fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
123 PlanProperties::new(
124 input.equivalence_properties().clone(), input.output_partitioning().clone(), input.pipeline_behavior(), input.boundedness(), )
129 }
130}
131
132impl DisplayAs for OutputRequirementExec {
133 fn fmt_as(
134 &self,
135 t: DisplayFormatType,
136 f: &mut std::fmt::Formatter,
137 ) -> std::fmt::Result {
138 match t {
139 DisplayFormatType::Default | DisplayFormatType::Verbose => {
140 write!(f, "OutputRequirementExec")
141 }
142 DisplayFormatType::TreeRender => {
143 write!(f, "")
145 }
146 }
147 }
148}
149
150impl ExecutionPlan for OutputRequirementExec {
151 fn name(&self) -> &'static str {
152 "OutputRequirementExec"
153 }
154
155 fn as_any(&self) -> &dyn std::any::Any {
156 self
157 }
158
159 fn properties(&self) -> &PlanProperties {
160 &self.cache
161 }
162
163 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
164 vec![false]
165 }
166
167 fn required_input_distribution(&self) -> Vec<Distribution> {
168 vec![self.dist_requirement.clone()]
169 }
170
171 fn maintains_input_order(&self) -> Vec<bool> {
172 vec![true]
173 }
174
175 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
176 vec![&self.input]
177 }
178
179 fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
180 vec![self.order_requirement.clone()]
181 }
182
183 fn with_new_children(
184 self: Arc<Self>,
185 mut children: Vec<Arc<dyn ExecutionPlan>>,
186 ) -> Result<Arc<dyn ExecutionPlan>> {
187 Ok(Arc::new(Self::new(
188 children.remove(0), self.order_requirement.clone(),
190 self.dist_requirement.clone(),
191 )))
192 }
193
194 fn execute(
195 &self,
196 _partition: usize,
197 _context: Arc<TaskContext>,
198 ) -> Result<SendableRecordBatchStream> {
199 unreachable!();
200 }
201
202 fn statistics(&self) -> Result<Statistics> {
203 self.input.partition_statistics(None)
204 }
205
206 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
207 self.input.partition_statistics(partition)
208 }
209
210 fn try_swapping_with_projection(
211 &self,
212 projection: &ProjectionExec,
213 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
214 if projection.expr().len() >= projection.input().schema().fields().len() {
216 return Ok(None);
217 }
218
219 let mut updated_sort_reqs = LexRequirement::new(vec![]);
220 if let Some(reqs) = &self.required_input_ordering()[0] {
222 for req in &reqs.inner {
223 let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)?
224 else {
225 return Ok(None);
226 };
227 updated_sort_reqs.push(PhysicalSortRequirement {
228 expr: new_expr,
229 options: req.options,
230 });
231 }
232 }
233
234 let dist_req = match &self.required_input_distribution()[0] {
235 Distribution::HashPartitioned(exprs) => {
236 let mut updated_exprs = vec![];
237 for expr in exprs {
238 let Some(new_expr) = update_expr(expr, projection.expr(), false)?
239 else {
240 return Ok(None);
241 };
242 updated_exprs.push(new_expr);
243 }
244 Distribution::HashPartitioned(updated_exprs)
245 }
246 dist => dist.clone(),
247 };
248
249 make_with_child(projection, &self.input())
250 .map(|input| {
251 OutputRequirementExec::new(
252 input,
253 (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs),
254 dist_req,
255 )
256 })
257 .map(|e| Some(Arc::new(e) as _))
258 }
259}
260
261impl PhysicalOptimizerRule for OutputRequirements {
262 fn optimize(
263 &self,
264 plan: Arc<dyn ExecutionPlan>,
265 _config: &ConfigOptions,
266 ) -> Result<Arc<dyn ExecutionPlan>> {
267 match self.mode {
268 RuleMode::Add => require_top_ordering(plan),
269 RuleMode::Remove => plan
270 .transform_up(|plan| {
271 if let Some(sort_req) =
272 plan.as_any().downcast_ref::<OutputRequirementExec>()
273 {
274 Ok(Transformed::yes(sort_req.input()))
275 } else {
276 Ok(Transformed::no(plan))
277 }
278 })
279 .data(),
280 }
281 }
282
283 fn name(&self) -> &str {
284 "OutputRequirements"
285 }
286
287 fn schema_check(&self) -> bool {
288 true
289 }
290}
291
292fn require_top_ordering(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
295 let (new_plan, is_changed) = require_top_ordering_helper(plan)?;
296 if is_changed {
297 Ok(new_plan)
298 } else {
299 Ok(Arc::new(OutputRequirementExec::new(
301 new_plan,
302 None,
304 Distribution::UnspecifiedDistribution,
305 )) as _)
306 }
307}
308
309fn require_top_ordering_helper(
313 plan: Arc<dyn ExecutionPlan>,
314) -> Result<(Arc<dyn ExecutionPlan>, bool)> {
315 let mut children = plan.children();
316 if children.len() != 1 {
318 Ok((plan, false))
319 } else if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
320 let req_ordering = sort_exec.expr();
323 let req_dist = sort_exec.required_input_distribution()[0].clone();
324 let reqs = LexRequirement::from(req_ordering.clone());
325 Ok((
326 Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _,
327 true,
328 ))
329 } else if let Some(spm) = plan.as_any().downcast_ref::<SortPreservingMergeExec>() {
330 let reqs = LexRequirement::from(spm.expr().clone());
331 Ok((
332 Arc::new(OutputRequirementExec::new(
333 plan,
334 Some(reqs),
335 Distribution::SinglePartition,
336 )) as _,
337 true,
338 ))
339 } else if plan.maintains_input_order()[0]
340 && plan.required_input_ordering()[0].is_none()
341 {
342 let (new_child, is_changed) =
347 require_top_ordering_helper(Arc::clone(children.swap_remove(0)))?;
348 Ok((plan.with_new_children(vec![new_child])?, is_changed))
349 } else {
350 Ok((plan, false))
352 }
353}
354
355