Skip to main content

datafusion_physical_optimizer/
ensure_coop.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! The [`EnsureCooperative`] optimizer rule inspects the physical plan to find all
19//! portions of the plan that will not yield cooperatively.
20//! It will insert `CooperativeExec` nodes where appropriate to ensure execution plans
21//! always yield cooperatively.
22
23use std::fmt::{Debug, Formatter};
24use std::sync::Arc;
25
26use crate::PhysicalOptimizerRule;
27
28use datafusion_common::Result;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode};
31use datafusion_physical_plan::ExecutionPlan;
32use datafusion_physical_plan::coop::CooperativeExec;
33use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType};
34
/// `EnsureCooperative` is a [`PhysicalOptimizerRule`] that inspects the physical plan for
/// sub plans that do not participate in cooperative scheduling. The plan is subdivided into sub
/// plans on eager evaluation boundaries. Leaf nodes and eager evaluation roots are checked
/// to see if they participate in cooperative scheduling. Those that do not are wrapped in
/// a [`CooperativeExec`] parent, unless a cooperative ancestor is reached before any eager
/// evaluation boundary when walking up from the node.
///
/// The rule is stateless: the struct has no fields.
pub struct EnsureCooperative {}
41
42impl EnsureCooperative {
43    pub fn new() -> Self {
44        Self {}
45    }
46}
47
48impl Default for EnsureCooperative {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl Debug for EnsureCooperative {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct(self.name()).finish()
57    }
58}
59
60impl PhysicalOptimizerRule for EnsureCooperative {
61    fn name(&self) -> &str {
62        "EnsureCooperative"
63    }
64
65    fn optimize(
66        &self,
67        plan: Arc<dyn ExecutionPlan>,
68        _config: &ConfigOptions,
69    ) -> Result<Arc<dyn ExecutionPlan>> {
70        use std::cell::RefCell;
71
72        let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new());
73
74        plan.transform_down_up(
75            // Down phase: Push parent properties <SchedulingType, EvaluationType> into the stack
76            |plan| {
77                let props = plan.properties();
78                ancestry_stack
79                    .borrow_mut()
80                    .push((props.scheduling_type, props.evaluation_type));
81                Ok(Transformed::no(plan))
82            },
83            // Up phase: Wrap nodes with CooperativeExec if needed
84            |plan| {
85                ancestry_stack.borrow_mut().pop();
86
87                let props = plan.properties();
88                let is_cooperative = props.scheduling_type == SchedulingType::Cooperative;
89                let is_leaf = plan.children().is_empty();
90                let is_exchange = props.evaluation_type == EvaluationType::Eager;
91
92                let mut is_under_cooperative_context = false;
93                for (scheduling_type, evaluation_type) in
94                    ancestry_stack.borrow().iter().rev()
95                {
96                    // If nearest ancestor is cooperative, we are under a cooperative context
97                    if *scheduling_type == SchedulingType::Cooperative {
98                        is_under_cooperative_context = true;
99                        break;
100                    // If nearest ancestor is eager, the cooperative context will be reset
101                    } else if *evaluation_type == EvaluationType::Eager {
102                        is_under_cooperative_context = false;
103                        break;
104                    }
105                }
106
107                // Wrap if:
108                // 1. Node is a leaf or exchange point
109                // 2. Node is not already cooperative
110                // 3. Not under any Cooperative context
111                if (is_leaf || is_exchange)
112                    && !is_cooperative
113                    && !is_under_cooperative_context
114                {
115                    return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan))));
116                }
117
118                Ok(Transformed::no(plan))
119            },
120        )
121        .map(|t| t.data)
122    }
123
124    fn schema_check(&self) -> bool {
125        // Wrapping a leaf in YieldStreamExec preserves the schema, so it is safe.
126        true
127    }
128}
129
130#[cfg(test)]
131mod tests {
132    use super::*;
133    use datafusion_physical_plan::{displayable, test::scan_partitioned};
134    use insta::assert_snapshot;
135
136    #[tokio::test]
137    async fn test_cooperative_exec_for_custom_exec() {
138        let test_custom_exec = scan_partitioned(1);
139        let config = ConfigOptions::new();
140        let optimized = EnsureCooperative::new()
141            .optimize(test_custom_exec, &config)
142            .unwrap();
143
144        let display = displayable(optimized.as_ref()).indent(true).to_string();
145        // Use insta snapshot to ensure full plan structure
146        assert_snapshot!(display, @r"
147        CooperativeExec
148          DataSourceExec: partitions=1, partition_sizes=[1]
149        ");
150    }
151
152    #[tokio::test]
153    async fn test_optimizer_is_idempotent() {
154        // Comprehensive idempotency test: verify f(f(...f(x))) = f(x)
155        // This test covers:
156        // 1. Multiple runs on unwrapped plan
157        // 2. Multiple runs on already-wrapped plan
158        // 3. No accumulation of CooperativeExec nodes
159
160        let config = ConfigOptions::new();
161        let rule = EnsureCooperative::new();
162
163        // Test 1: Start with unwrapped plan, run multiple times
164        let unwrapped_plan = scan_partitioned(1);
165        let mut current = unwrapped_plan;
166        let mut stable_result = String::new();
167
168        for run in 1..=5 {
169            current = rule.optimize(current, &config).unwrap();
170            let display = displayable(current.as_ref()).indent(true).to_string();
171
172            if run == 1 {
173                stable_result = display.clone();
174                assert_eq!(display.matches("CooperativeExec").count(), 1);
175            } else {
176                assert_eq!(
177                    display, stable_result,
178                    "Run {run} should match run 1 (idempotent)"
179                );
180                assert_eq!(
181                    display.matches("CooperativeExec").count(),
182                    1,
183                    "Should always have exactly 1 CooperativeExec, not accumulate"
184                );
185            }
186        }
187
188        // Test 2: Start with already-wrapped plan, verify no double wrapping
189        let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1)));
190        let result = rule.optimize(pre_wrapped, &config).unwrap();
191        let display = displayable(result.as_ref()).indent(true).to_string();
192
193        assert_eq!(
194            display.matches("CooperativeExec").count(),
195            1,
196            "Should not double-wrap already cooperative plans"
197        );
198        assert_eq!(
199            display, stable_result,
200            "Pre-wrapped plan should produce same result as unwrapped after optimization"
201        );
202    }
203
204    #[tokio::test]
205    async fn test_selective_wrapping() {
206        // Test that wrapping is selective: only leaf/eager nodes, not intermediate nodes
207        // Also verify depth tracking prevents double wrapping in subtrees
208        use datafusion_physical_expr::expressions::lit;
209        use datafusion_physical_plan::filter::FilterExec;
210
211        let config = ConfigOptions::new();
212        let rule = EnsureCooperative::new();
213
214        // Case 1: Filter -> Scan (middle node should not be wrapped)
215        let scan = scan_partitioned(1);
216        let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap());
217        let optimized = rule.optimize(filter, &config).unwrap();
218        let display = displayable(optimized.as_ref()).indent(true).to_string();
219
220        assert_eq!(display.matches("CooperativeExec").count(), 1);
221        assert!(display.contains("FilterExec"));
222
223        // Case 2: Filter -> CoopExec -> Scan (depth tracking prevents double wrap)
224        let scan2 = scan_partitioned(1);
225        let wrapped_scan = Arc::new(CooperativeExec::new(scan2));
226        let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap());
227        let optimized2 = rule.optimize(filter2, &config).unwrap();
228        let display2 = displayable(optimized2.as_ref()).indent(true).to_string();
229
230        assert_eq!(display2.matches("CooperativeExec").count(), 1);
231    }
232
233    #[tokio::test]
234    async fn test_multiple_leaf_nodes() {
235        // When there are multiple leaf nodes, each should be wrapped separately
236        use datafusion_physical_plan::union::UnionExec;
237
238        let scan1 = scan_partitioned(1);
239        let scan2 = scan_partitioned(1);
240        let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();
241
242        let config = ConfigOptions::new();
243        let optimized = EnsureCooperative::new()
244            .optimize(union as Arc<dyn ExecutionPlan>, &config)
245            .unwrap();
246
247        let display = displayable(optimized.as_ref()).indent(true).to_string();
248
249        // Each leaf should have its own CooperativeExec
250        assert_eq!(
251            display.matches("CooperativeExec").count(),
252            2,
253            "Each leaf node should be wrapped separately"
254        );
255        assert_eq!(
256            display.matches("DataSourceExec").count(),
257            2,
258            "Both data sources should be present"
259        );
260    }
261
262    #[tokio::test]
263    async fn test_eager_evaluation_resets_cooperative_context() {
264        // Test that cooperative context is reset when encountering an eager evaluation boundary.
265        use arrow::datatypes::Schema;
266        use datafusion_common::internal_err;
267        use datafusion_execution::TaskContext;
268        use datafusion_physical_expr::EquivalenceProperties;
269        use datafusion_physical_plan::{
270            DisplayAs, DisplayFormatType, Partitioning, PlanProperties,
271            SendableRecordBatchStream,
272            execution_plan::{Boundedness, EmissionType},
273        };
274
275        #[derive(Debug)]
276        struct DummyExec {
277            name: String,
278            input: Arc<dyn ExecutionPlan>,
279            scheduling_type: SchedulingType,
280            evaluation_type: EvaluationType,
281            properties: Arc<PlanProperties>,
282        }
283
284        impl DummyExec {
285            fn new(
286                name: &str,
287                input: Arc<dyn ExecutionPlan>,
288                scheduling_type: SchedulingType,
289                evaluation_type: EvaluationType,
290            ) -> Self {
291                let properties = PlanProperties::new(
292                    EquivalenceProperties::new(Arc::new(Schema::empty())),
293                    Partitioning::UnknownPartitioning(1),
294                    EmissionType::Incremental,
295                    Boundedness::Bounded,
296                )
297                .with_scheduling_type(scheduling_type)
298                .with_evaluation_type(evaluation_type);
299
300                Self {
301                    name: name.to_string(),
302                    input,
303                    scheduling_type,
304                    evaluation_type,
305                    properties: Arc::new(properties),
306                }
307            }
308        }
309
310        impl DisplayAs for DummyExec {
311            fn fmt_as(
312                &self,
313                _: DisplayFormatType,
314                f: &mut Formatter,
315            ) -> std::fmt::Result {
316                write!(f, "{}", self.name)
317            }
318        }
319
320        impl ExecutionPlan for DummyExec {
321            fn name(&self) -> &str {
322                &self.name
323            }
324            fn properties(&self) -> &Arc<PlanProperties> {
325                &self.properties
326            }
327            fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
328                vec![&self.input]
329            }
330            fn with_new_children(
331                self: Arc<Self>,
332                children: Vec<Arc<dyn ExecutionPlan>>,
333            ) -> Result<Arc<dyn ExecutionPlan>> {
334                Ok(Arc::new(DummyExec::new(
335                    &self.name,
336                    Arc::clone(&children[0]),
337                    self.scheduling_type,
338                    self.evaluation_type,
339                )))
340            }
341            fn execute(
342                &self,
343                _: usize,
344                _: Arc<TaskContext>,
345            ) -> Result<SendableRecordBatchStream> {
346                internal_err!("DummyExec does not support execution")
347            }
348        }
349
350        // Build a plan similar to the original test:
351        // scan -> exch1(NonCoop,Eager) -> CoopExec -> filter -> exch2(Coop,Eager) -> filter
352        let scan = scan_partitioned(1);
353        let exch1 = Arc::new(DummyExec::new(
354            "exch1",
355            scan,
356            SchedulingType::NonCooperative,
357            EvaluationType::Eager,
358        ));
359        let coop = Arc::new(CooperativeExec::new(exch1));
360        let filter1 = Arc::new(DummyExec::new(
361            "filter1",
362            coop,
363            SchedulingType::NonCooperative,
364            EvaluationType::Lazy,
365        ));
366        let exch2 = Arc::new(DummyExec::new(
367            "exch2",
368            filter1,
369            SchedulingType::Cooperative,
370            EvaluationType::Eager,
371        ));
372        let filter2 = Arc::new(DummyExec::new(
373            "filter2",
374            exch2,
375            SchedulingType::NonCooperative,
376            EvaluationType::Lazy,
377        ));
378
379        let config = ConfigOptions::new();
380        let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap();
381
382        let display = displayable(optimized.as_ref()).indent(true).to_string();
383
384        // Expected wrapping:
385        // - Scan (leaf) gets wrapped
386        // - exch1 (eager+noncoop) keeps its manual CooperativeExec wrapper
387        // - filter1 is protected by exch2's cooperative context, no extra wrap
388        // - exch2 (already Cooperative) does NOT get wrapped
389        // - filter2 (not leaf or eager) does NOT get wrapped
390        assert_eq!(
391            display.matches("CooperativeExec").count(),
392            2,
393            "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1"
394        );
395
396        assert_snapshot!(display, @r"
397        filter2
398          exch2
399            filter1
400              CooperativeExec
401                exch1
402                  CooperativeExec
403                    DataSourceExec: partitions=1, partition_sizes=[1]
404        ");
405    }
406}