datafusion_physical_optimizer/
ensure_coop.rs1use std::fmt::{Debug, Formatter};
24use std::sync::Arc;
25
26use crate::PhysicalOptimizerRule;
27
28use datafusion_common::Result;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode};
31use datafusion_physical_plan::ExecutionPlan;
32use datafusion_physical_plan::coop::CooperativeExec;
33use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType};
34
/// Physical optimizer rule that inserts [`CooperativeExec`] wrappers into a plan.
///
/// The rule's `optimize` implementation wraps a node when it is a leaf (no
/// children) or reports `EvaluationType::Eager`, unless the node already
/// reports `SchedulingType::Cooperative` or is covered by a cooperative
/// ancestor. The type itself carries no state.
pub struct EnsureCooperative {}

impl EnsureCooperative {
    /// Creates a new, stateless `EnsureCooperative` rule.
    pub fn new() -> Self {
        Self {}
    }
}

impl Default for EnsureCooperative {
    /// Equivalent to [`EnsureCooperative::new`].
    fn default() -> Self {
        Self::new()
    }
}
53
54impl Debug for EnsureCooperative {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct(self.name()).finish()
57 }
58}
59
60impl PhysicalOptimizerRule for EnsureCooperative {
61 fn name(&self) -> &str {
62 "EnsureCooperative"
63 }
64
65 fn optimize(
66 &self,
67 plan: Arc<dyn ExecutionPlan>,
68 _config: &ConfigOptions,
69 ) -> Result<Arc<dyn ExecutionPlan>> {
70 use std::cell::RefCell;
71
72 let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new());
73
74 plan.transform_down_up(
75 |plan| {
77 let props = plan.properties();
78 ancestry_stack
79 .borrow_mut()
80 .push((props.scheduling_type, props.evaluation_type));
81 Ok(Transformed::no(plan))
82 },
83 |plan| {
85 ancestry_stack.borrow_mut().pop();
86
87 let props = plan.properties();
88 let is_cooperative = props.scheduling_type == SchedulingType::Cooperative;
89 let is_leaf = plan.children().is_empty();
90 let is_exchange = props.evaluation_type == EvaluationType::Eager;
91
92 let mut is_under_cooperative_context = false;
93 for (scheduling_type, evaluation_type) in
94 ancestry_stack.borrow().iter().rev()
95 {
96 if *scheduling_type == SchedulingType::Cooperative {
98 is_under_cooperative_context = true;
99 break;
100 } else if *evaluation_type == EvaluationType::Eager {
102 is_under_cooperative_context = false;
103 break;
104 }
105 }
106
107 if (is_leaf || is_exchange)
112 && !is_cooperative
113 && !is_under_cooperative_context
114 {
115 return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan))));
116 }
117
118 Ok(Transformed::no(plan))
119 },
120 )
121 .map(|t| t.data)
122 }
123
124 fn schema_check(&self) -> bool {
125 true
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_physical_plan::{displayable, test::scan_partitioned};
    use insta::assert_snapshot;

    /// A lone scan (a leaf) is wrapped in a single `CooperativeExec`.
    #[tokio::test]
    async fn test_cooperative_exec_for_custom_exec() {
        let test_custom_exec = scan_partitioned(1);
        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new()
            .optimize(test_custom_exec, &config)
            .unwrap();

        let display = displayable(optimized.as_ref()).indent(true).to_string();
        // Snapshot the whole rendering so the wrapper's position is pinned too.
        assert_snapshot!(display, @r"
        CooperativeExec
          DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    /// Re-running the rule on its own output must not change the plan, and a
    /// plan that was wrapped by hand must converge to the same result.
    #[tokio::test]
    async fn test_optimizer_is_idempotent() {
        let config = ConfigOptions::new();
        let rule = EnsureCooperative::new();

        // Part 1: start from an unwrapped scan and apply the rule repeatedly.
        let unwrapped_plan = scan_partitioned(1);
        let mut current = unwrapped_plan;
        let mut stable_result = String::new();

        for run in 1..=5 {
            current = rule.optimize(current, &config).unwrap();
            let display = displayable(current.as_ref()).indent(true).to_string();

            if run == 1 {
                // The first run establishes the reference rendering.
                stable_result = display.clone();
                assert_eq!(display.matches("CooperativeExec").count(), 1);
            } else {
                assert_eq!(
                    display, stable_result,
                    "Run {run} should match run 1 (idempotent)"
                );
                assert_eq!(
                    display.matches("CooperativeExec").count(),
                    1,
                    "Should always have exactly 1 CooperativeExec, not accumulate"
                );
            }
        }

        // Part 2: a scan that is already wrapped must not be wrapped again.
        let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1)));
        let result = rule.optimize(pre_wrapped, &config).unwrap();
        let display = displayable(result.as_ref()).indent(true).to_string();

        assert_eq!(
            display.matches("CooperativeExec").count(),
            1,
            "Should not double-wrap already cooperative plans"
        );
        assert_eq!(
            display, stable_result,
            "Pre-wrapped plan should produce same result as unwrapped after optimization"
        );
    }

    /// Only the scan below a filter is wrapped, and an existing wrapper below
    /// the filter is not duplicated.
    #[tokio::test]
    async fn test_selective_wrapping() {
        use datafusion_physical_expr::expressions::lit;
        use datafusion_physical_plan::filter::FilterExec;

        let config = ConfigOptions::new();
        let rule = EnsureCooperative::new();

        // Case 1: Filter -> Scan. Exactly one wrapper is expected in the output.
        let scan = scan_partitioned(1);
        let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap());
        let optimized = rule.optimize(filter, &config).unwrap();
        let display = displayable(optimized.as_ref()).indent(true).to_string();

        assert_eq!(display.matches("CooperativeExec").count(), 1);
        assert!(display.contains("FilterExec"));

        // Case 2: Filter -> CooperativeExec -> Scan. The existing wrapper must
        // remain the only one.
        let scan2 = scan_partitioned(1);
        let wrapped_scan = Arc::new(CooperativeExec::new(scan2));
        let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap());
        let optimized2 = rule.optimize(filter2, &config).unwrap();
        let display2 = displayable(optimized2.as_ref()).indent(true).to_string();

        assert_eq!(display2.matches("CooperativeExec").count(), 1);
    }

    /// Every leaf under a union receives its own wrapper.
    #[tokio::test]
    async fn test_multiple_leaf_nodes() {
        use datafusion_physical_plan::union::UnionExec;

        let scan1 = scan_partitioned(1);
        let scan2 = scan_partitioned(1);
        let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();

        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new()
            .optimize(union as Arc<dyn ExecutionPlan>, &config)
            .unwrap();

        let display = displayable(optimized.as_ref()).indent(true).to_string();

        assert_eq!(
            display.matches("CooperativeExec").count(),
            2,
            "Each leaf node should be wrapped separately"
        );
        assert_eq!(
            display.matches("DataSourceExec").count(),
            2,
            "Both data sources should be present"
        );
    }

    /// An eager, non-cooperative ancestor resets the cooperative context, so a
    /// leaf below it is wrapped even though a cooperative node sits higher up.
    #[tokio::test]
    async fn test_eager_evaluation_resets_cooperative_context() {
        use arrow::datatypes::Schema;
        use datafusion_common::internal_err;
        use datafusion_execution::TaskContext;
        use datafusion_physical_expr::EquivalenceProperties;
        use datafusion_physical_plan::{
            DisplayAs, DisplayFormatType, Partitioning, PlanProperties,
            SendableRecordBatchStream,
            execution_plan::{Boundedness, EmissionType},
        };

        /// Single-child plan node whose scheduling and evaluation types are
        /// chosen by the test. It cannot be executed.
        #[derive(Debug)]
        struct DummyExec {
            name: String,
            input: Arc<dyn ExecutionPlan>,
            scheduling_type: SchedulingType,
            evaluation_type: EvaluationType,
            properties: Arc<PlanProperties>,
        }

        impl DummyExec {
            /// Builds a node with an empty schema, one unknown partition and
            /// the requested scheduling/evaluation types.
            fn new(
                name: &str,
                input: Arc<dyn ExecutionPlan>,
                scheduling_type: SchedulingType,
                evaluation_type: EvaluationType,
            ) -> Self {
                let properties = PlanProperties::new(
                    EquivalenceProperties::new(Arc::new(Schema::empty())),
                    Partitioning::UnknownPartitioning(1),
                    EmissionType::Incremental,
                    Boundedness::Bounded,
                )
                .with_scheduling_type(scheduling_type)
                .with_evaluation_type(evaluation_type);

                Self {
                    name: name.to_string(),
                    input,
                    scheduling_type,
                    evaluation_type,
                    properties: Arc::new(properties),
                }
            }
        }

        impl DisplayAs for DummyExec {
            // Render only the node's name, in every display format.
            fn fmt_as(
                &self,
                _: DisplayFormatType,
                f: &mut Formatter,
            ) -> std::fmt::Result {
                write!(f, "{}", self.name)
            }
        }

        impl ExecutionPlan for DummyExec {
            fn name(&self) -> &str {
                &self.name
            }
            fn properties(&self) -> &Arc<PlanProperties> {
                &self.properties
            }
            fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
                vec![&self.input]
            }
            // Rebuild the node around the new child, keeping name and types.
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn ExecutionPlan>>,
            ) -> Result<Arc<dyn ExecutionPlan>> {
                Ok(Arc::new(DummyExec::new(
                    &self.name,
                    Arc::clone(&children[0]),
                    self.scheduling_type,
                    self.evaluation_type,
                )))
            }
            fn execute(
                &self,
                _: usize,
                _: Arc<TaskContext>,
            ) -> Result<SendableRecordBatchStream> {
                internal_err!("DummyExec does not support execution")
            }
        }

        // Plan under test, listed from leaf to root:
        //   scan -> exch1 (NonCooperative, Eager) -> CooperativeExec
        //        -> filter1 (NonCooperative, Lazy) -> exch2 (Cooperative, Eager)
        //        -> filter2 (NonCooperative, Lazy)
        let scan = scan_partitioned(1);
        let exch1 = Arc::new(DummyExec::new(
            "exch1",
            scan,
            SchedulingType::NonCooperative,
            EvaluationType::Eager,
        ));
        let coop = Arc::new(CooperativeExec::new(exch1));
        let filter1 = Arc::new(DummyExec::new(
            "filter1",
            coop,
            SchedulingType::NonCooperative,
            EvaluationType::Lazy,
        ));
        let exch2 = Arc::new(DummyExec::new(
            "exch2",
            filter1,
            SchedulingType::Cooperative,
            EvaluationType::Eager,
        ));
        let filter2 = Arc::new(DummyExec::new(
            "filter2",
            exch2,
            SchedulingType::NonCooperative,
            EvaluationType::Lazy,
        ));

        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap();

        let display = displayable(optimized.as_ref()).indent(true).to_string();

        // Expected outcome:
        // - scan: a leaf whose nearest ancestor (exch1) is eager and not
        //   cooperative, so it gains a new wrapper;
        // - exch1: keeps the wrapper that was placed around it by hand;
        // - filter1 and filter2: neither leaf nor eager, never wrapped;
        // - exch2: already cooperative, never wrapped.
        assert_eq!(
            display.matches("CooperativeExec").count(),
            2,
            "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1"
        );

        assert_snapshot!(display, @r"
        filter2
          exch2
            filter1
              CooperativeExec
                exch1
                  CooperativeExec
                    DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }
}