// datafusion_physical_optimizer/ensure_coop.rs

use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use crate::PhysicalOptimizerRule;

use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::coop::CooperativeExec;
use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType};

/// Physical optimizer rule that wraps plan nodes in [`CooperativeExec`].
///
/// A node is wrapped when it is a leaf or is eagerly evaluated, is not already
/// cooperative, and is not covered by a cooperative ancestor. The rule carries
/// no state or configuration of its own.
pub struct EnsureCooperative {}

42impl EnsureCooperative {
43 pub fn new() -> Self {
44 Self {}
45 }
46}
47
48impl Default for EnsureCooperative {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl Debug for EnsureCooperative {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct(self.name()).finish()
57 }
58}
59
60impl PhysicalOptimizerRule for EnsureCooperative {
61 fn name(&self) -> &str {
62 "EnsureCooperative"
63 }
64
65 fn optimize(
66 &self,
67 plan: Arc<dyn ExecutionPlan>,
68 _config: &ConfigOptions,
69 ) -> Result<Arc<dyn ExecutionPlan>> {
70 use std::cell::RefCell;
71
72 let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new());
73
74 plan.transform_down_up(
75 |plan| {
77 let props = plan.properties();
78 ancestry_stack
79 .borrow_mut()
80 .push((props.scheduling_type, props.evaluation_type));
81 Ok(Transformed::no(plan))
82 },
83 |plan| {
85 ancestry_stack.borrow_mut().pop();
86
87 let props = plan.properties();
88 let is_cooperative = props.scheduling_type == SchedulingType::Cooperative;
89 let is_leaf = plan.children().is_empty();
90 let is_exchange = props.evaluation_type == EvaluationType::Eager;
91
92 let mut is_under_cooperative_context = false;
93 for (scheduling_type, evaluation_type) in
94 ancestry_stack.borrow().iter().rev()
95 {
96 if *scheduling_type == SchedulingType::Cooperative {
98 is_under_cooperative_context = true;
99 break;
100 } else if *evaluation_type == EvaluationType::Eager {
102 is_under_cooperative_context = false;
103 break;
104 }
105 }
106
107 if (is_leaf || is_exchange)
112 && !is_cooperative
113 && !is_under_cooperative_context
114 {
115 return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan))));
116 }
117
118 Ok(Transformed::no(plan))
119 },
120 )
121 .map(|t| t.data)
122 }
123
124 fn schema_check(&self) -> bool {
125 true
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::config::ConfigOptions;
    use datafusion_physical_plan::{displayable, test::scan_partitioned};
    use insta::assert_snapshot;

    /// Renders `plan` in the indented display format used by the assertions.
    fn render(plan: &dyn ExecutionPlan) -> String {
        displayable(plan).indent(true).to_string()
    }

    /// Counts the occurrences of `CooperativeExec` in a rendered plan.
    fn coop_count(display: &str) -> usize {
        display.matches("CooperativeExec").count()
    }

    /// A lone scan is a leaf, so it must end up wrapped.
    #[tokio::test]
    async fn test_cooperative_exec_for_custom_exec() {
        let test_custom_exec = scan_partitioned(1);
        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new()
            .optimize(test_custom_exec, &config)
            .unwrap();

        let display = render(optimized.as_ref());
        assert_snapshot!(display, @r"
        CooperativeExec
          DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }

    /// Re-running the rule on its own output must not add further wrappers,
    /// and a manually pre-wrapped plan must converge to the same result.
    #[tokio::test]
    async fn test_optimizer_is_idempotent() {
        let config = ConfigOptions::new();
        let rule = EnsureCooperative::new();

        // Run 1 establishes the reference rendering.
        let mut current = rule.optimize(scan_partitioned(1), &config).unwrap();
        let stable_result = render(current.as_ref());
        assert_eq!(coop_count(&stable_result), 1);

        // Runs 2..=5 feed the previous output back in.
        for run in 2..=5 {
            current = rule.optimize(current, &config).unwrap();
            let display = render(current.as_ref());

            assert_eq!(
                display, stable_result,
                "Run {run} should match run 1 (idempotent)"
            );
            assert_eq!(
                coop_count(&display),
                1,
                "Should always have exactly 1 CooperativeExec, not accumulate"
            );
        }

        let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1)));
        let result = rule.optimize(pre_wrapped, &config).unwrap();
        let display = render(result.as_ref());

        assert_eq!(
            coop_count(&display),
            1,
            "Should not double-wrap already cooperative plans"
        );
        assert_eq!(
            display, stable_result,
            "Pre-wrapped plan should produce same result as unwrapped after optimization"
        );
    }

    /// Only the leaf below a filter is wrapped, and an existing wrapper is
    /// not duplicated.
    #[tokio::test]
    async fn test_selective_wrapping() {
        use datafusion_physical_expr::expressions::lit;
        use datafusion_physical_plan::filter::FilterExec;

        let config = ConfigOptions::new();
        let rule = EnsureCooperative::new();

        // Filter over a bare scan: exactly one wrapper is inserted.
        let scan = scan_partitioned(1);
        let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap());
        let optimized = rule.optimize(filter, &config).unwrap();
        let display = render(optimized.as_ref());

        assert_eq!(coop_count(&display), 1);
        assert!(display.contains("FilterExec"));

        // Filter over an already wrapped scan: the count stays at one.
        let scan2 = scan_partitioned(1);
        let wrapped_scan = Arc::new(CooperativeExec::new(scan2));
        let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap());
        let optimized2 = rule.optimize(filter2, &config).unwrap();
        let display2 = render(optimized2.as_ref());

        assert_eq!(coop_count(&display2), 1);
    }

    /// Every leaf of a multi-input plan gets its own wrapper.
    #[tokio::test]
    async fn test_multiple_leaf_nodes() {
        use datafusion_physical_plan::union::UnionExec;

        let scan1 = scan_partitioned(1);
        let scan2 = scan_partitioned(1);
        let union = UnionExec::try_new(vec![scan1, scan2]).unwrap();

        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new()
            .optimize(union as Arc<dyn ExecutionPlan>, &config)
            .unwrap();

        let display = render(optimized.as_ref());

        assert_eq!(
            coop_count(&display),
            2,
            "Each leaf node should be wrapped separately"
        );
        assert_eq!(
            display.matches("DataSourceExec").count(),
            2,
            "Both data sources should be present"
        );
    }

    /// A non-cooperative eager ancestor cancels the cooperative context
    /// provided by nodes further up the plan.
    #[tokio::test]
    async fn test_eager_evaluation_resets_cooperative_context() {
        use arrow::datatypes::Schema;
        use datafusion_common::{Result, internal_err};
        use datafusion_execution::TaskContext;
        use datafusion_physical_expr::EquivalenceProperties;
        use datafusion_physical_plan::{
            DisplayAs, DisplayFormatType, Partitioning, PlanProperties,
            SendableRecordBatchStream,
            execution_plan::{Boundedness, EmissionType},
        };
        use std::any::Any;
        use std::fmt::Formatter;

        /// Single-input plan node whose scheduling and evaluation types are
        /// chosen by the test. It cannot be executed.
        #[derive(Debug)]
        struct DummyExec {
            name: String,
            input: Arc<dyn ExecutionPlan>,
            scheduling_type: SchedulingType,
            evaluation_type: EvaluationType,
            properties: Arc<PlanProperties>,
        }

        impl DummyExec {
            /// Builds a node named `name` over `input` that reports the given
            /// scheduling and evaluation types through its plan properties.
            fn new(
                name: &str,
                input: Arc<dyn ExecutionPlan>,
                scheduling_type: SchedulingType,
                evaluation_type: EvaluationType,
            ) -> Self {
                let properties = PlanProperties::new(
                    EquivalenceProperties::new(Arc::new(Schema::empty())),
                    Partitioning::UnknownPartitioning(1),
                    EmissionType::Incremental,
                    Boundedness::Bounded,
                )
                .with_scheduling_type(scheduling_type)
                .with_evaluation_type(evaluation_type);

                Self {
                    name: name.to_string(),
                    input,
                    scheduling_type,
                    evaluation_type,
                    properties: Arc::new(properties),
                }
            }
        }

        impl DisplayAs for DummyExec {
            // Displays just the node name, whatever the requested format.
            fn fmt_as(
                &self,
                _: DisplayFormatType,
                f: &mut Formatter,
            ) -> std::fmt::Result {
                write!(f, "{}", self.name)
            }
        }

        impl ExecutionPlan for DummyExec {
            fn name(&self) -> &str {
                &self.name
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn properties(&self) -> &Arc<PlanProperties> {
                &self.properties
            }
            fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
                vec![&self.input]
            }
            // Rebuilds the node over the first new child, keeping its name and types.
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn ExecutionPlan>>,
            ) -> Result<Arc<dyn ExecutionPlan>> {
                Ok(Arc::new(DummyExec::new(
                    &self.name,
                    Arc::clone(&children[0]),
                    self.scheduling_type,
                    self.evaluation_type,
                )))
            }
            fn execute(
                &self,
                _: usize,
                _: Arc<TaskContext>,
            ) -> Result<SendableRecordBatchStream> {
                internal_err!("DummyExec does not support execution")
            }
        }

        // Plan under test, outermost first:
        //   filter2 (lazy) -> exch2 (cooperative, eager) -> filter1 (lazy)
        //     -> CooperativeExec (inserted by hand) -> exch1 (non-cooperative, eager)
        //     -> scan
        let scan = scan_partitioned(1);
        let exch1 = Arc::new(DummyExec::new(
            "exch1",
            scan,
            SchedulingType::NonCooperative,
            EvaluationType::Eager,
        ));
        let coop = Arc::new(CooperativeExec::new(exch1));
        let filter1 = Arc::new(DummyExec::new(
            "filter1",
            coop,
            SchedulingType::NonCooperative,
            EvaluationType::Lazy,
        ));
        let exch2 = Arc::new(DummyExec::new(
            "exch2",
            filter1,
            SchedulingType::Cooperative,
            EvaluationType::Eager,
        ));
        let filter2 = Arc::new(DummyExec::new(
            "filter2",
            exch2,
            SchedulingType::NonCooperative,
            EvaluationType::Lazy,
        ));

        let config = ConfigOptions::new();
        let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap();

        let display = render(optimized.as_ref());

        // Expected outcome:
        // - the scan (leaf) is wrapped: its nearest ancestor, exch1, is eager
        //   and non-cooperative, so no cooperative context reaches it;
        // - exch1 keeps the hand-inserted wrapper and gains no second one;
        // - filter1 and filter2 are neither leaves nor eager, so never wrapped;
        // - exch2 is already cooperative, so it is left alone.
        assert_eq!(
            coop_count(&display),
            2,
            "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1"
        );

        assert_snapshot!(display, @r"
        filter2
          exch2
            filter1
              CooperativeExec
                exch1
                  CooperativeExec
                    DataSourceExec: partitions=1, partition_sizes=[1]
        ");
    }
}