datafusion_distributed/distributed_planner/
task_estimator.rs1use crate::DistributedConfig;
2use crate::config_extension_ext::set_distributed_option_extension;
3use crate::execution_plans::DistributedLeafExec;
4use TaskCountAnnotation::*;
5use datafusion::catalog::memory::DataSourceExec;
6use datafusion::config::ConfigOptions;
7use datafusion::datasource::physical_plan::{FileGroup, FileGroupPartitioner, FileScanConfig};
8use datafusion::error::Result;
9use datafusion::execution::TaskContext;
10use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
11use datafusion::prelude::SessionConfig;
12use delegate::delegate;
13use std::fmt::Debug;
14use std::sync::Arc;
15use url::Url;
16
/// Task-count hint attached to a plan node by a [`TaskEstimator`].
///
/// Both variants carry a number of tasks; they differ in how they combine:
/// - `Desired(n)`: a preferred task count. Two `Desired` hints combine to the larger one,
///   and any `Maximum` hint takes precedence over it.
/// - `Maximum(n)`: an upper bound on the task count. Two `Maximum` hints combine to the
///   smaller one.
#[derive(Clone, Copy, Debug)]
pub enum TaskCountAnnotation {
    /// Preferred number of tasks.
    Desired(usize),
    /// Upper bound on the number of tasks.
    Maximum(usize),
}
29
30impl From<TaskCountAnnotation> for usize {
31 fn from(annotation: TaskCountAnnotation) -> Self {
32 annotation.as_usize()
33 }
34}
35
36impl TaskCountAnnotation {
37 pub fn as_usize(&self) -> usize {
38 match self {
39 Desired(desired) => *desired,
40 Maximum(maximum) => *maximum,
41 }
42 }
43
44 pub(crate) fn limit(self, limit: usize) -> Self {
45 match self {
46 Desired(desired) => Desired(desired.min(limit)),
47 Maximum(maximum) => Maximum(maximum.min(limit)),
48 }
49 }
50
51 pub(crate) fn merge(self, other: TaskCountAnnotation) -> Self {
52 match (self, other) {
53 (Desired(a), Desired(b)) => Desired(std::cmp::max(a, b)),
54 (Desired(_), Maximum(b)) => Maximum(b),
55 (Maximum(a), Desired(_)) => Maximum(a),
56 (Maximum(a), Maximum(b)) => Maximum(std::cmp::min(a, b)),
57 }
58 }
59}
60
61pub struct TaskEstimation {
64 pub task_count: TaskCountAnnotation,
73}
74
75impl TaskEstimation {
76 pub fn maximum(value: usize) -> Self {
85 TaskEstimation {
86 task_count: TaskCountAnnotation::Maximum(value),
87 }
88 }
89
90 pub fn desired(value: usize) -> Self {
97 TaskEstimation {
98 task_count: TaskCountAnnotation::Desired(value),
99 }
100 }
101}
102
103pub trait TaskEstimator {
111 fn task_estimation(
124 &self,
125 plan: &Arc<dyn ExecutionPlan>,
126 cfg: &ConfigOptions,
127 ) -> Option<TaskEstimation>;
128
129 fn scale_up_leaf_node(
133 &self,
134 plan: &Arc<dyn ExecutionPlan>,
135 task_count: usize,
136 cfg: &ConfigOptions,
137 ) -> Result<Option<Arc<dyn ExecutionPlan>>>;
138
139 fn route_tasks(&self, _routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>> {
146 Ok(None)
147 }
148}
149
150pub struct TaskRoutingContext<'a> {
152 pub task_ctx: Arc<TaskContext>,
154 pub plan: &'a Arc<dyn ExecutionPlan>,
156 pub task_count: usize,
158 pub available_urls: &'a [Url],
161}
162
163impl TaskEstimator for usize {
164 fn task_estimation(
165 &self,
166 inputs: &Arc<dyn ExecutionPlan>,
167 _: &ConfigOptions,
168 ) -> Option<TaskEstimation> {
169 if inputs.children().is_empty() {
170 Some(TaskEstimation {
171 task_count: TaskCountAnnotation::Desired(*self),
172 })
173 } else {
174 None
175 }
176 }
177
178 fn scale_up_leaf_node(
179 &self,
180 _: &Arc<dyn ExecutionPlan>,
181 _: usize,
182 _: &ConfigOptions,
183 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
184 Ok(None)
185 }
186}
187
188impl TaskEstimator for Arc<dyn TaskEstimator> {
189 delegate! {
190 to self.as_ref() {
191 fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
192 fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Result<Option<Arc<dyn ExecutionPlan>>>;
193 fn route_tasks(&self, routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>>;
194 }
195 }
196}
197
198impl TaskEstimator for Arc<dyn TaskEstimator + Send + Sync> {
199 delegate! {
200 to self.as_ref() {
201 fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
202 fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Result<Option<Arc<dyn ExecutionPlan>>>;
203 fn route_tasks(&self, routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>>;
204 }
205 }
206}
207
208pub(crate) fn set_distributed_task_estimator(
209 cfg: &mut SessionConfig,
210 estimator: impl TaskEstimator + Send + Sync + 'static,
211) {
212 let opts = cfg.options_mut();
213 if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
214 distributed_cfg
215 .__private_task_estimator
216 .user_provided
217 .push(Arc::new(estimator));
218 } else {
219 let mut estimators = CombinedTaskEstimator::default();
220 estimators.user_provided.push(Arc::new(estimator));
221 set_distributed_option_extension(
222 cfg,
223 DistributedConfig {
224 __private_task_estimator: estimators,
225 ..Default::default()
226 },
227 )
228 }
229}
230
/// Built-in estimator for file scans, i.e. a `DataSourceExec` whose data source is a
/// `FileScanConfig`.
///
/// It sizes the task count from the number of scanned bytes. When scaling up, it
/// redistributes the scan's file groups across per-task copies of the scan.
#[derive(Debug)]
pub(crate) struct FileScanConfigTaskEstimator;
237
238impl TaskEstimator for FileScanConfigTaskEstimator {
239 fn task_estimation(
240 &self,
241 plan: &Arc<dyn ExecutionPlan>,
242 cfg: &ConfigOptions,
243 ) -> Option<TaskEstimation> {
244 let dse: &DataSourceExec = plan.downcast_ref()?;
245 let file_scan: &FileScanConfig = dse.data_source().downcast_ref()?;
246
247 let d_cfg = cfg.extensions.get::<DistributedConfig>()?;
248
249 let mut total_bytes = 0;
250 for file_group in &file_scan.file_groups {
251 for file in file_group.files() {
252 total_bytes += file.effective_size() as usize
253 }
254 }
255
256 let task_count = total_bytes
257 .div_ceil(d_cfg.file_scan_config_bytes_per_partition)
258 .div_ceil(cfg.execution.target_partitions);
259
260 Some(TaskEstimation::desired(task_count))
261 }
262
263 fn scale_up_leaf_node(
264 &self,
265 plan: &Arc<dyn ExecutionPlan>,
266 task_count: usize,
267 _cfg: &ConfigOptions,
268 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
269 let Some(dse) = plan.downcast_ref::<DataSourceExec>() else {
270 return Ok(None);
271 };
272 let Some(file_scan) = dse.data_source().downcast_ref::<FileScanConfig>() else {
273 return Ok(None);
274 };
275 let partition_count = plan.output_partitioning().partition_count();
276
277 let rebalanced = if file_scan.partitioned_by_file_group {
278 let all_partitioned_files = file_scan
279 .file_groups
280 .iter()
281 .flat_map(|file_group| file_group.iter().cloned())
282 .collect::<Vec<_>>();
283 rebalance_round_robin(all_partitioned_files, partition_count * task_count)
284 .into_iter()
285 .map(FileGroup::new)
286 .collect::<Vec<_>>()
287 } else {
288 FileGroupPartitioner::new()
289 .with_target_partitions(partition_count * task_count)
290 .with_repartition_file_min_size(0)
294 .with_preserve_order_within_groups(!file_scan.output_ordering.is_empty())
295 .repartition_file_groups(&file_scan.file_groups)
296 .unwrap_or_else(|| file_scan.file_groups.clone())
297 .into_iter()
298 .collect()
299 };
300
301 let mut file_scan_template = file_scan.clone();
302 file_scan_template.file_groups.clear();
303 let mut file_scans = vec![file_scan_template; task_count];
304 for (i, file_group) in rebalanced.into_iter().enumerate() {
305 file_scans[i % task_count].file_groups.push(file_group);
306 }
307
308 let dle = DistributedLeafExec::try_new(
309 Arc::clone(plan),
310 file_scans
311 .into_iter()
312 .map(|file_scan| DataSourceExec::from_data_source(file_scan) as _),
313 )?;
314
315 Ok(Some(Arc::new(dle)))
316 }
317}
318
/// Deals `items` round-robin into `target_groups` groups: item `i` lands in group
/// `i % target_groups`.
///
/// Group sizes therefore differ by at most one, regardless of how the items were grouped
/// before. For a non-zero `target_groups` the result always has exactly `target_groups`
/// entries, and trailing groups are empty when there are fewer items than groups.
///
/// When `target_groups` is `0`, no item is ever dropped:
/// - a non-empty input is returned as a single group;
/// - an empty input yields no groups.
fn rebalance_round_robin<T>(items: Vec<T>, target_groups: usize) -> Vec<Vec<T>> {
    if target_groups == 0 {
        // `idx % 0` would panic; keep every item instead of losing data.
        return if items.is_empty() {
            Vec::new()
        } else {
            vec![items]
        };
    }
    let mut groups: Vec<Vec<T>> = Vec::with_capacity(target_groups);
    groups.resize_with(target_groups, Vec::new);
    for (idx, item) in items.into_iter().enumerate() {
        groups[idx % target_groups].push(item);
    }
    groups
}
328
329#[derive(Clone, Default)]
333pub(crate) struct CombinedTaskEstimator {
334 pub(crate) user_provided: Vec<Arc<dyn TaskEstimator + Send + Sync>>,
335}
336
337impl TaskEstimator for CombinedTaskEstimator {
338 fn task_estimation(
339 &self,
340 plan: &Arc<dyn ExecutionPlan>,
341 cfg: &ConfigOptions,
342 ) -> Option<TaskEstimation> {
343 for estimator in &self.user_provided {
344 if let Some(result) = estimator.task_estimation(plan, cfg) {
345 return Some(result);
346 }
347 }
348 for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
352 if let Some(result) = default_estimator.task_estimation(plan, cfg) {
353 return Some(result);
354 }
355 }
356 None
357 }
358
359 fn scale_up_leaf_node(
360 &self,
361 plan: &Arc<dyn ExecutionPlan>,
362 task_count: usize,
363 cfg: &ConfigOptions,
364 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
365 for estimator in &self.user_provided {
366 if let Some(result) = estimator.scale_up_leaf_node(plan, task_count, cfg)? {
367 return Ok(Some(result));
368 }
369 }
370 for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
374 if let Some(result) = default_estimator.scale_up_leaf_node(plan, task_count, cfg)? {
375 return Ok(Some(result));
376 }
377 }
378 Ok(None)
379 }
380
381 fn route_tasks(&self, routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>> {
382 for estimator in &self.user_provided {
383 if let Some(result) = estimator.route_tasks(routing_ctx)? {
384 return Ok(Some(result));
385 }
386 }
387 Ok(None)
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::networking::WorkerResolverExtension;
    use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver;
    use crate::test_utils::parquet::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::prelude::SessionContext;

    /// User estimator that never has an opinion, forcing fall-through to the next one.
    fn no_opinion(_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions) -> Option<TaskEstimation> {
        None
    }

    #[tokio::test]
    async fn test_first_user_estimator_wins() -> Result<(), DataFusionError> {
        let leaf = make_data_source_exec().await?;
        let mut combined = CombinedTaskEstimator::default();
        combined.push(10usize);
        combined.push(20usize);

        let estimated = combined.task_count(leaf, |d_cfg| d_cfg);
        assert_eq!(estimated, 10);
        Ok(())
    }

    #[tokio::test]
    async fn test_continues_until_some() -> Result<(), DataFusionError> {
        let leaf = make_data_source_exec().await?;
        let mut combined = CombinedTaskEstimator::default();
        combined.push(no_opinion);
        combined.push(30usize);

        let estimated = combined.task_count(leaf, |d_cfg| d_cfg);
        assert_eq!(estimated, 30);
        Ok(())
    }

    #[tokio::test]
    async fn test_defaults_to_file_scan_config_task_estimator() -> Result<(), DataFusionError> {
        let leaf = make_data_source_exec().await?;
        let mut combined = CombinedTaskEstimator::default();
        combined.push(no_opinion);

        // Roughly a third of the scanned bytes per partition; with one partition per task
        // (see `task_count`) the built-in estimator should ask for three tasks.
        let bytes_per_partition = total_scan_bytes(&leaf).div_ceil(3);
        let estimated = combined.task_count(leaf, |d_cfg| DistributedConfig {
            file_scan_config_bytes_per_partition: bytes_per_partition,
            ..d_cfg
        });
        assert_eq!(estimated, 3);
        Ok(())
    }

    /// Sums the effective size of every file scanned by the file-scan leaf `node`.
    fn total_scan_bytes(node: &Arc<dyn ExecutionPlan>) -> usize {
        let file_scan = node
            .downcast_ref::<DataSourceExec>()
            .and_then(|dse| dse.data_source().downcast_ref::<FileScanConfig>())
            .unwrap();
        let mut total = 0;
        for file in file_scan.file_groups.iter().flat_map(|group| group.files()) {
            total += file.effective_size() as usize;
        }
        total
    }

    #[test]
    fn test_rebalance_round_robin_fixes_group_boundary_skew() {
        let sizes: Vec<usize> = rebalance_round_robin((0..8).collect::<Vec<_>>(), 5)
            .iter()
            .map(|group| group.len())
            .collect();
        assert_eq!(sizes, vec![2, 2, 2, 1, 1]);
    }

    #[test]
    fn test_rebalance_round_robin_pads_with_empty_groups() {
        let sizes: Vec<usize> = rebalance_round_robin(vec![10, 20, 30], 5)
            .iter()
            .map(|group| group.len())
            .collect();
        assert_eq!(sizes, vec![1, 1, 1, 0, 0]);
    }

    impl CombinedTaskEstimator {
        /// Registers `value` as the next user-provided estimator.
        fn push(&mut self, value: impl TaskEstimator + Send + Sync + 'static) {
            let estimator: Arc<dyn TaskEstimator + Send + Sync> = Arc::new(value);
            self.user_provided.push(estimator);
        }

        /// Runs the combined estimation for `node`.
        ///
        /// The `ConfigOptions` use a single target partition. `f` may adjust the test
        /// `DistributedConfig` before it is installed. Panics if no estimator answers.
        fn task_count(
            &self,
            node: Arc<dyn ExecutionPlan>,
            f: impl FnOnce(DistributedConfig) -> DistributedConfig,
        ) -> usize {
            let base = DistributedConfig {
                file_scan_config_bytes_per_partition: 1,
                __private_worker_resolver: WorkerResolverExtension(Arc::new(
                    InMemoryWorkerResolver::new(3),
                )),
                ..Default::default()
            };
            let mut cfg = ConfigOptions::default();
            // One partition per task keeps expected task counts equal to partition counts.
            cfg.execution.target_partitions = 1;
            cfg.extensions.insert(f(base));
            let estimation = self.task_estimation(&node, &cfg).unwrap();
            estimation.task_count.as_usize()
        }
    }

    /// Plans `SELECT * FROM weather` and follows first children down to the leaf node.
    async fn make_data_source_exec() -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let ctx = SessionContext::new();
        register_parquet_tables(&ctx).await?;
        let df = ctx.sql("SELECT * FROM weather").await?;
        let mut plan = df.create_physical_plan().await?;
        loop {
            let next = match plan.children().first() {
                Some(child) => Arc::clone(*child),
                None => break,
            };
            plan = next;
        }
        Ok(plan)
    }

    /// Lets plain closures and functions act as estimators in tests; they never scale up
    /// leaf nodes.
    impl<F> TaskEstimator for F
    where
        F: Fn(&Arc<dyn ExecutionPlan>, &ConfigOptions) -> Option<TaskEstimation>,
    {
        fn task_estimation(
            &self,
            plan: &Arc<dyn ExecutionPlan>,
            cfg: &ConfigOptions,
        ) -> Option<TaskEstimation> {
            self(plan, cfg)
        }

        fn scale_up_leaf_node(
            &self,
            _plan: &Arc<dyn ExecutionPlan>,
            _task_count: usize,
            _cfg: &ConfigOptions,
        ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
            Ok(None)
        }
    }
}