Skip to main content

datafusion_physical_optimizer/
ensure_coop.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! The [`EnsureCooperative`] optimizer rule inspects the physical plan to find all
//! portions of the plan that will not yield cooperatively.
//! It inserts `CooperativeExec` nodes where appropriate to ensure execution plans
//! always yield cooperatively.
22
23use std::fmt::{Debug, Formatter};
24use std::sync::Arc;
25
26use crate::PhysicalOptimizerRule;
27
28use datafusion_common::Result;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode};
31use datafusion_physical_plan::ExecutionPlan;
32use datafusion_physical_plan::coop::CooperativeExec;
33use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType};
34
/// `EnsureCooperative` is a [`PhysicalOptimizerRule`] that inspects the physical plan for
/// sub plans that do not participate in cooperative scheduling. The plan is subdivided into sub
/// plans on eager evaluation boundaries. Leaf nodes and eager evaluation roots are checked
/// to see if they participate in cooperative scheduling. Those that do not, and that are not
/// already below a cooperative ancestor within the same sub plan, are wrapped in a
/// [`CooperativeExec`] parent.
pub struct EnsureCooperative {}
41
42impl EnsureCooperative {
43    pub fn new() -> Self {
44        Self {}
45    }
46}
47
48impl Default for EnsureCooperative {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl Debug for EnsureCooperative {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct(self.name()).finish()
57    }
58}
59
60impl PhysicalOptimizerRule for EnsureCooperative {
61    fn name(&self) -> &str {
62        "EnsureCooperative"
63    }
64
65    fn optimize(
66        &self,
67        plan: Arc<dyn ExecutionPlan>,
68        _config: &ConfigOptions,
69    ) -> Result<Arc<dyn ExecutionPlan>> {
70        use std::cell::RefCell;
71
72        let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new());
73
74        plan.transform_down_up(
75            // Down phase: Push parent properties <SchedulingType, EvaluationType> into the stack
76            |plan| {
77                let props = plan.properties();
78                ancestry_stack
79                    .borrow_mut()
80                    .push((props.scheduling_type, props.evaluation_type));
81                Ok(Transformed::no(plan))
82            },
83            // Up phase: Wrap nodes with CooperativeExec if needed
84            |plan| {
85                ancestry_stack.borrow_mut().pop();
86
87                let props = plan.properties();
88                let is_cooperative = props.scheduling_type == SchedulingType::Cooperative;
89                let is_leaf = plan.children().is_empty();
90                let is_exchange = props.evaluation_type == EvaluationType::Eager;
91
92                let mut is_under_cooperative_context = false;
93                for (scheduling_type, evaluation_type) in
94                    ancestry_stack.borrow().iter().rev()
95                {
96                    // If nearest ancestor is cooperative, we are under a cooperative context
97                    if *scheduling_type == SchedulingType::Cooperative {
98                        is_under_cooperative_context = true;
99                        break;
100                    // If nearest ancestor is eager, the cooperative context will be reset
101                    } else if *evaluation_type == EvaluationType::Eager {
102                        is_under_cooperative_context = false;
103                        break;
104                    }
105                }
106
107                // Wrap if:
108                // 1. Node is a leaf or exchange point
109                // 2. Node is not already cooperative
110                // 3. Not under any Cooperative context
111                if (is_leaf || is_exchange)
112                    && !is_cooperative
113                    && !is_under_cooperative_context
114                {
115                    return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan))));
116                }
117
118                Ok(Transformed::no(plan))
119            },
120        )
121        .map(|t| t.data)
122    }
123
124    fn schema_check(&self) -> bool {
125        // Wrapping a leaf in YieldStreamExec preserves the schema, so it is safe.
126        true
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::config::ConfigOptions;
    use datafusion_physical_plan::{displayable, test::scan_partitioned};
    use insta::assert_snapshot;

    #[tokio::test]
    async fn test_cooperative_exec_for_custom_exec() {
        let test_custom_exec = scan_partitioned(1);
        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new()
            .optimize(test_custom_exec, &config)
            .unwrap();

        let display = displayable(optimized.as_ref()).indent(true).to_string();
        // Use insta snapshot to ensure full plan structure
        assert_snapshot!(display, @r"
        CooperativeExec
          DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    #[tokio::test]
    async fn test_optimizer_is_idempotent() {
        // Comprehensive idempotency test: verify f(f(...f(x))) = f(x)
        // This test covers:
        // 1. Multiple runs on unwrapped plan
        // 2. Multiple runs on already-wrapped plan
        // 3. No accumulation of CooperativeExec nodes

        let config = ConfigOptions::new();
        let rule = EnsureCooperative::new();

        // Test 1: Start with unwrapped plan, run multiple times
        let unwrapped_plan = scan_partitioned(1);
        let mut current = unwrapped_plan;
        let mut stable_result = String::new();

        for run in 1..=5 {
            current = rule.optimize(current, &config).unwrap();
            let display = displayable(current.as_ref()).indent(true).to_string();

            if run == 1 {
                stable_result = display.clone();
                assert_eq!(display.matches("CooperativeExec").count(), 1);
            } else {
                assert_eq!(
                    display, stable_result,
                    "Run {run} should match run 1 (idempotent)"
                );
                assert_eq!(
                    display.matches("CooperativeExec").count(),
                    1,
                    "Should always have exactly 1 CooperativeExec, not accumulate"
                );
            }
        }

        // Test 2: Start with already-wrapped plan, verify no double wrapping
        let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1)));
        let result = rule.optimize(pre_wrapped, &config).unwrap();
        let display = displayable(result.as_ref()).indent(true).to_string();

        assert_eq!(
            display.matches("CooperativeExec").count(),
            1,
            "Should not double-wrap already cooperative plans"
        );
        assert_eq!(
            display, stable_result,
            "Pre-wrapped plan should produce same result as unwrapped after optimization"
        );
    }

    #[tokio::test]
    async fn test_selective_wrapping() {
        // Test that wrapping is selective: only leaf/eager nodes, not intermediate nodes.
        // Also verify that a cooperative ancestor prevents double wrapping in its subtree.
        use datafusion_physical_expr::expressions::lit;
        use datafusion_physical_plan::filter::FilterExec;

        let config = ConfigOptions::new();
        let rule = EnsureCooperative::new();

        // Case 1: Filter -> Scan (middle node should not be wrapped)
        let scan = scan_partitioned(1);
        let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap());
        let optimized = rule.optimize(filter, &config).unwrap();
        let display = displayable(optimized.as_ref()).indent(true).to_string();

        assert_eq!(display.matches("CooperativeExec").count(), 1);
        assert!(display.contains("FilterExec"));

        // Case 2: Filter -> CoopExec -> Scan (the CooperativeExec ancestor puts the
        // scan in a cooperative context, so it is not wrapped a second time)
        let scan2 = scan_partitioned(1);
        let wrapped_scan = Arc::new(CooperativeExec::new(scan2));
        let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap());
        let optimized2 = rule.optimize(filter2, &config).unwrap();
        let display2 = displayable(optimized2.as_ref()).indent(true).to_string();

        assert_eq!(display2.matches("CooperativeExec").count(), 1);

        // Case 3: CoopExec -> Filter -> Filter -> Scan (the cooperative context is
        // inherited through several lazy operators, so the leaf is not wrapped)
        let scan3 = scan_partitioned(1);
        let inner_filter = Arc::new(FilterExec::try_new(lit(true), scan3).unwrap());
        let outer_filter =
            Arc::new(FilterExec::try_new(lit(true), inner_filter).unwrap());
        let coop_root = Arc::new(CooperativeExec::new(outer_filter));
        let optimized3 = rule.optimize(coop_root, &config).unwrap();
        let display3 = displayable(optimized3.as_ref()).indent(true).to_string();

        assert_eq!(display3.matches("CooperativeExec").count(), 1);
        assert_eq!(display3.matches("FilterExec").count(), 2);
    }

    #[tokio::test]
    async fn test_multiple_leaf_nodes() {
        // When there are multiple leaf nodes, each should be wrapped separately
        use datafusion_physical_plan::union::UnionExec;

        let scan1 = scan_partitioned(1);
        let scan2 = scan_partitioned(1);
        let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();

        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new()
            .optimize(union as Arc<dyn ExecutionPlan>, &config)
            .unwrap();

        let display = displayable(optimized.as_ref()).indent(true).to_string();

        // Each leaf should have its own CooperativeExec
        assert_eq!(
            display.matches("CooperativeExec").count(),
            2,
            "Each leaf node should be wrapped separately"
        );
        assert_eq!(
            display.matches("DataSourceExec").count(),
            2,
            "Both data sources should be present"
        );
    }

    #[tokio::test]
    async fn test_eager_evaluation_resets_cooperative_context() {
        // Test that cooperative context is reset when encountering an eager evaluation boundary.
        use arrow::datatypes::Schema;
        use datafusion_common::{Result, internal_err};
        use datafusion_execution::TaskContext;
        use datafusion_physical_expr::EquivalenceProperties;
        use datafusion_physical_plan::{
            DisplayAs, DisplayFormatType, Partitioning, PlanProperties,
            SendableRecordBatchStream,
            execution_plan::{Boundedness, EmissionType},
        };
        use std::any::Any;
        use std::fmt::Formatter;

        /// Single-input plan node with configurable scheduling and evaluation
        /// types. It exists only to shape plans for the optimizer and cannot execute.
        #[derive(Debug)]
        struct DummyExec {
            name: String,
            input: Arc<dyn ExecutionPlan>,
            scheduling_type: SchedulingType,
            evaluation_type: EvaluationType,
            properties: Arc<PlanProperties>,
        }

        impl DummyExec {
            fn new(
                name: &str,
                input: Arc<dyn ExecutionPlan>,
                scheduling_type: SchedulingType,
                evaluation_type: EvaluationType,
            ) -> Self {
                let properties = PlanProperties::new(
                    EquivalenceProperties::new(Arc::new(Schema::empty())),
                    Partitioning::UnknownPartitioning(1),
                    EmissionType::Incremental,
                    Boundedness::Bounded,
                )
                .with_scheduling_type(scheduling_type)
                .with_evaluation_type(evaluation_type);

                Self {
                    name: name.to_string(),
                    input,
                    scheduling_type,
                    evaluation_type,
                    properties: Arc::new(properties),
                }
            }
        }

        impl DisplayAs for DummyExec {
            fn fmt_as(
                &self,
                _: DisplayFormatType,
                f: &mut Formatter,
            ) -> std::fmt::Result {
                write!(f, "{}", self.name)
            }
        }

        impl ExecutionPlan for DummyExec {
            fn name(&self) -> &str {
                &self.name
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn properties(&self) -> &Arc<PlanProperties> {
                &self.properties
            }
            fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
                vec![&self.input]
            }
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn ExecutionPlan>>,
            ) -> Result<Arc<dyn ExecutionPlan>> {
                Ok(Arc::new(DummyExec::new(
                    &self.name,
                    Arc::clone(&children[0]),
                    self.scheduling_type,
                    self.evaluation_type,
                )))
            }
            fn execute(
                &self,
                _: usize,
                _: Arc<TaskContext>,
            ) -> Result<SendableRecordBatchStream> {
                internal_err!("DummyExec does not support execution")
            }
        }

        // Build the plan bottom-up (each node is the input of the next one):
        // scan -> exch1(NonCoop,Eager) -> CoopExec -> filter1 -> exch2(Coop,Eager) -> filter2
        let scan = scan_partitioned(1);
        let exch1 = Arc::new(DummyExec::new(
            "exch1",
            scan,
            SchedulingType::NonCooperative,
            EvaluationType::Eager,
        ));
        let coop = Arc::new(CooperativeExec::new(exch1));
        let filter1 = Arc::new(DummyExec::new(
            "filter1",
            coop,
            SchedulingType::NonCooperative,
            EvaluationType::Lazy,
        ));
        let exch2 = Arc::new(DummyExec::new(
            "exch2",
            filter1,
            SchedulingType::Cooperative,
            EvaluationType::Eager,
        ));
        let filter2 = Arc::new(DummyExec::new(
            "filter2",
            exch2,
            SchedulingType::NonCooperative,
            EvaluationType::Lazy,
        ));

        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap();

        let display = displayable(optimized.as_ref()).indent(true).to_string();

        // Expected wrapping:
        // - Scan (leaf) gets wrapped: its nearest relevant ancestor, exch1, is a
        //   non-cooperative eager node, which resets the cooperative context
        // - exch1 (eager+noncoop) keeps its manual CooperativeExec wrapper and is
        //   not wrapped again, because that wrapper is a cooperative ancestor
        // - filter1 (lazy, not a leaf) is never a wrapping candidate
        // - exch2 (already Cooperative) does NOT get wrapped
        // - filter2 (not leaf or eager) does NOT get wrapped
        assert_eq!(
            display.matches("CooperativeExec").count(),
            2,
            "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1"
        );

        assert_snapshot!(display, @r"
        filter2
          exch2
            filter1
              CooperativeExec
                exch1
                  CooperativeExec
                    DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }
}