Skip to main content

datafusion_distributed/distributed_planner/
task_estimator.rs

1use crate::DistributedConfig;
2use crate::config_extension_ext::set_distributed_option_extension;
3use crate::execution_plans::DistributedLeafExec;
4use TaskCountAnnotation::*;
5use datafusion::catalog::memory::DataSourceExec;
6use datafusion::config::ConfigOptions;
7use datafusion::datasource::physical_plan::{FileGroup, FileGroupPartitioner, FileScanConfig};
8use datafusion::error::Result;
9use datafusion::execution::TaskContext;
10use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
11use datafusion::prelude::SessionConfig;
12use delegate::delegate;
13use std::fmt::Debug;
14use std::sync::Arc;
15use url::Url;
16
17/// Annotation attached to a single [ExecutionPlan] that determines how many distributed tasks
18/// it should run on.
19#[derive(Debug, Clone, Copy)]
20pub enum TaskCountAnnotation {
21    /// The desired number of distributed tasks for this node. The final task count for the
22    /// annotated node might not be exactly this number, it is more like a hint, so depending
23    /// on the desired task count of adjacent nodes, the final task count might change.
24    Desired(usize),
25    /// Sets a maximum number of distributed tasks for this node. Typically used with the inner
26    /// value of 1, stating that this node cannot be executed in a distributed fashion.
27    Maximum(usize),
28}
29
30impl From<TaskCountAnnotation> for usize {
31    fn from(annotation: TaskCountAnnotation) -> Self {
32        annotation.as_usize()
33    }
34}
35
36impl TaskCountAnnotation {
37    pub fn as_usize(&self) -> usize {
38        match self {
39            Desired(desired) => *desired,
40            Maximum(maximum) => *maximum,
41        }
42    }
43
44    pub(crate) fn limit(self, limit: usize) -> Self {
45        match self {
46            Desired(desired) => Desired(desired.min(limit)),
47            Maximum(maximum) => Maximum(maximum.min(limit)),
48        }
49    }
50
51    pub(crate) fn merge(self, other: TaskCountAnnotation) -> Self {
52        match (self, other) {
53            (Desired(a), Desired(b)) => Desired(std::cmp::max(a, b)),
54            (Desired(_), Maximum(b)) => Maximum(b),
55            (Maximum(a), Desired(_)) => Maximum(a),
56            (Maximum(a), Maximum(b)) => Maximum(std::cmp::min(a, b)),
57        }
58    }
59}
60
61/// Result of running a [TaskEstimator] on a leaf node. It tells the distributed planner hints
62/// about how many tasks should be used in [Stage]s that contain leaf nodes.
63pub struct TaskEstimation {
64    /// The number of tasks that should be used in the [Stage] containing the leaf node.
65    ///
66    /// Even if implementations get to decide this number, there are situations where it can
67    /// get overridden:
68    /// - If a [Stage] contains multiple leaf nodes, the one that declares the biggest
69    ///   task_count wins.
70    /// - If there are less available workers than this number, the number of available workers
71    ///   is chosen.
72    pub task_count: TaskCountAnnotation,
73}
74
75impl TaskEstimation {
76    /// Tells the distributed planner that the evaluated stage can have **at maximum** the provided
77    /// number of tasks, setting a hard upper limit.
78    ///
79    /// Returning `TaskEstimation::maximum(1)` tells the distributed planner that the evaluated
80    /// stage cannot be distributed.
81    ///
82    /// Even if a `TaskEstimation::maximum(N)` is provided, any other node in the same stage
83    /// providing a value of `TaskEstimation::maximum(M)` where `M` < `N` will have preference.
84    pub fn maximum(value: usize) -> Self {
85        TaskEstimation {
86            task_count: TaskCountAnnotation::Maximum(value),
87        }
88    }
89
90    /// Tells the distributed planner that the evaluated can **optimally** have the provided
91    /// number of tasks, setting a soft task count hint that can be overridden by others.
92    ///
93    /// The provided `TaskEstimation::desired(N)` can be overridden by:
94    /// - Other nodes providing a `TaskEstimation::desired(M)` where `M` > `N`.
95    /// - Any other node providing a `TaskEstimation::maximum(M)` where `M` can be anything.
96    pub fn desired(value: usize) -> Self {
97        TaskEstimation {
98            task_count: TaskCountAnnotation::Desired(value),
99        }
100    }
101}
102
/// Given a leaf node, provides an estimation of how many tasks should be used in the stage
/// containing it, and decides whether the leaf node should be replaced by another one.
///
/// The distributed planner tries the registered [TaskEstimator]s in order until one provides
/// an estimation for a specific leaf node. Once that's done, upper stages get their task
/// count calculated based on whether lower stages are reducing the cardinality of the data
/// or increasing it.
pub trait TaskEstimator {
    /// Function applied to each node that returns a [TaskEstimation] hinting how many
    /// tasks should be used in the [Stage] containing that node. Returning `None` means this
    /// estimator has no opinion about the node.
    ///
    /// All the [TaskEstimator]s registered in the session are applied to the node until one
    /// returns an estimation.
    ///
    /// If no estimation is returned from any of the registered [TaskEstimator]s, then:
    /// - If the node is a leaf node, `Maximum(1)` is assumed, hinting to the distributed
    ///   planner that the leaf node cannot be distributed across tasks.
    /// - If the node is a normal node in the plan, then the maximum task count from its
    ///   children is inherited.
    fn task_estimation(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        cfg: &ConfigOptions,
    ) -> Option<TaskEstimation>;

    /// After a final `task_count` is decided, taking into account all the leaf nodes in the
    /// [Stage], this allows transforming the leaf nodes to account for the fact that they are
    /// going to run in multiple tasks.
    ///
    /// Returns `Ok(Some(new_plan))` with the replacement node, or `Ok(None)` when this
    /// estimator does not transform the given node.
    fn scale_up_leaf_node(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        task_count: usize,
        cfg: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>>;

    /// Optionally defines a custom protocol for routing tasks to specific worker URLs. Receives
    /// a routing context including the task count and a list of available URLs, and returns a
    /// vector of routed URLs, in order of task assignment.
    ///
    /// If `Ok(Some(Vec<Url>))` is returned, tasks are sent in order to the URLs specified in the
    /// returned vector. If `Ok(None)` is returned, execution defaults to round-robin routing.
    ///
    /// The default implementation returns `Ok(None)`.
    fn route_tasks(&self, _routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>> {
        Ok(None)
    }
}
149
/// Context handed to [TaskEstimator::route_tasks] for routing tasks to worker URLs.
///
/// The `'a` lifetime ties the context to the borrowed plan and URL list.
pub struct TaskRoutingContext<'a> {
    /// The task context active at routing time.
    pub task_ctx: Arc<TaskContext>,
    /// The head execution plan of the stage being routed.
    pub plan: &'a Arc<dyn ExecutionPlan>,
    /// The number of tasks to be assigned.
    pub task_count: usize,
    /// List of URLs representing machines available to receive a task. These URLs are
    /// sourced at execution time and thus should closely reflect the real state of the cluster.
    pub available_urls: &'a [Url],
}
162
163impl TaskEstimator for usize {
164    fn task_estimation(
165        &self,
166        inputs: &Arc<dyn ExecutionPlan>,
167        _: &ConfigOptions,
168    ) -> Option<TaskEstimation> {
169        if inputs.children().is_empty() {
170            Some(TaskEstimation {
171                task_count: TaskCountAnnotation::Desired(*self),
172            })
173        } else {
174            None
175        }
176    }
177
178    fn scale_up_leaf_node(
179        &self,
180        _: &Arc<dyn ExecutionPlan>,
181        _: usize,
182        _: &ConfigOptions,
183    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
184        Ok(None)
185    }
186}
187
/// Lets an `Arc<dyn TaskEstimator>` be used wherever a [TaskEstimator] is expected by
/// forwarding every trait method to the wrapped trait object.
impl TaskEstimator for Arc<dyn TaskEstimator> {
    // `route_tasks` is forwarded explicitly as well; otherwise the trait's default
    // implementation (which returns `Ok(None)`) would be used instead of the inner one.
    delegate! {
        to self.as_ref() {
            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Result<Option<Arc<dyn ExecutionPlan>>>;
            fn route_tasks(&self, routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>>;
        }
    }
}
197
/// Same forwarding implementation as for `Arc<dyn TaskEstimator>`, but for the
/// `Send + Sync` flavor of the trait object, which is a distinct type.
impl TaskEstimator for Arc<dyn TaskEstimator + Send + Sync> {
    // `route_tasks` is forwarded explicitly as well; otherwise the trait's default
    // implementation (which returns `Ok(None)`) would be used instead of the inner one.
    delegate! {
        to self.as_ref() {
            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Result<Option<Arc<dyn ExecutionPlan>>>;
            fn route_tasks(&self, routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>>;
        }
    }
}
207
208pub(crate) fn set_distributed_task_estimator(
209    cfg: &mut SessionConfig,
210    estimator: impl TaskEstimator + Send + Sync + 'static,
211) {
212    let opts = cfg.options_mut();
213    if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
214        distributed_cfg
215            .__private_task_estimator
216            .user_provided
217            .push(Arc::new(estimator));
218    } else {
219        let mut estimators = CombinedTaskEstimator::default();
220        estimators.user_provided.push(Arc::new(estimator));
221        set_distributed_option_extension(
222            cfg,
223            DistributedConfig {
224                __private_task_estimator: estimators,
225                ..Default::default()
226            },
227        )
228    }
229}
230
231/// [TaskEstimator] implementation that acts on [DataSourceExec] nodes that contain
232/// [FileScanConfig]s data sources (e.g., Parquet or CSV files). It reads the
233/// [DistributedConfig].`file_scan_config_bytes_per_partition` field and assigns as many tasks as
234/// needed so that no partition scans more than the configured number of bytes.
235#[derive(Debug)]
236pub(crate) struct FileScanConfigTaskEstimator;
237
238impl TaskEstimator for FileScanConfigTaskEstimator {
239    fn task_estimation(
240        &self,
241        plan: &Arc<dyn ExecutionPlan>,
242        cfg: &ConfigOptions,
243    ) -> Option<TaskEstimation> {
244        let dse: &DataSourceExec = plan.downcast_ref()?;
245        let file_scan: &FileScanConfig = dse.data_source().downcast_ref()?;
246
247        let d_cfg = cfg.extensions.get::<DistributedConfig>()?;
248
249        let mut total_bytes = 0;
250        for file_group in &file_scan.file_groups {
251            for file in file_group.files() {
252                total_bytes += file.effective_size() as usize
253            }
254        }
255
256        let task_count = total_bytes
257            .div_ceil(d_cfg.file_scan_config_bytes_per_partition)
258            .div_ceil(cfg.execution.target_partitions);
259
260        Some(TaskEstimation::desired(task_count))
261    }
262
263    fn scale_up_leaf_node(
264        &self,
265        plan: &Arc<dyn ExecutionPlan>,
266        task_count: usize,
267        _cfg: &ConfigOptions,
268    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
269        let Some(dse) = plan.downcast_ref::<DataSourceExec>() else {
270            return Ok(None);
271        };
272        let Some(file_scan) = dse.data_source().downcast_ref::<FileScanConfig>() else {
273            return Ok(None);
274        };
275        let partition_count = plan.output_partitioning().partition_count();
276
277        let rebalanced = if file_scan.partitioned_by_file_group {
278            let all_partitioned_files = file_scan
279                .file_groups
280                .iter()
281                .flat_map(|file_group| file_group.iter().cloned())
282                .collect::<Vec<_>>();
283            rebalance_round_robin(all_partitioned_files, partition_count * task_count)
284                .into_iter()
285                .map(FileGroup::new)
286                .collect::<Vec<_>>()
287        } else {
288            FileGroupPartitioner::new()
289                .with_target_partitions(partition_count * task_count)
290                // Allow repartitioning beyond normal limits, putting the limit in
291                // `partition_count * task_count` target partitions, and not in the
292                // resulting size.
293                .with_repartition_file_min_size(0)
294                .with_preserve_order_within_groups(!file_scan.output_ordering.is_empty())
295                .repartition_file_groups(&file_scan.file_groups)
296                .unwrap_or_else(|| file_scan.file_groups.clone())
297                .into_iter()
298                .collect()
299        };
300
301        let mut file_scan_template = file_scan.clone();
302        file_scan_template.file_groups.clear();
303        let mut file_scans = vec![file_scan_template; task_count];
304        for (i, file_group) in rebalanced.into_iter().enumerate() {
305            file_scans[i % task_count].file_groups.push(file_group);
306        }
307
308        let dle = DistributedLeafExec::try_new(
309            Arc::clone(plan),
310            file_scans
311                .into_iter()
312                .map(|file_scan| DataSourceExec::from_data_source(file_scan) as _),
313        )?;
314
315        Ok(Some(Arc::new(dle)))
316    }
317}
318
/// Deals `items` into exactly `target_groups` groups in round-robin order: item `i` lands in
/// group `i % target_groups`. Groups that receive no item are kept as empty vectors.
///
/// # Panics
///
/// Panics if `target_groups` is 0 while `items` is not empty.
fn rebalance_round_robin<T>(items: Vec<T>, target_groups: usize) -> Vec<Vec<T>> {
    let mut buckets: Vec<Vec<T>> = Vec::with_capacity(target_groups);
    buckets.resize_with(target_groups, Vec::new);

    for (position, item) in items.into_iter().enumerate() {
        let slot = position % target_groups;
        buckets[slot].push(item);
    }
    buckets
}
328
329/// Tries multiple user-provided [TaskEstimator]s until one returns an estimation. If none
330/// returns an estimation, a set of default [TaskEstimation] implementations is tried. Right
331/// now the only default [TaskEstimation] is [FileScanConfigTaskEstimator].
332#[derive(Clone, Default)]
333pub(crate) struct CombinedTaskEstimator {
334    pub(crate) user_provided: Vec<Arc<dyn TaskEstimator + Send + Sync>>,
335}
336
337impl TaskEstimator for CombinedTaskEstimator {
338    fn task_estimation(
339        &self,
340        plan: &Arc<dyn ExecutionPlan>,
341        cfg: &ConfigOptions,
342    ) -> Option<TaskEstimation> {
343        for estimator in &self.user_provided {
344            if let Some(result) = estimator.task_estimation(plan, cfg) {
345                return Some(result);
346            }
347        }
348        // We want to execute the default estimators last so that the user-provided ones have
349        // a chance of providing an estimation.
350        // If none of the user-provided returned an estimation, the default ones are used.
351        for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
352            if let Some(result) = default_estimator.task_estimation(plan, cfg) {
353                return Some(result);
354            }
355        }
356        None
357    }
358
359    fn scale_up_leaf_node(
360        &self,
361        plan: &Arc<dyn ExecutionPlan>,
362        task_count: usize,
363        cfg: &ConfigOptions,
364    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
365        for estimator in &self.user_provided {
366            if let Some(result) = estimator.scale_up_leaf_node(plan, task_count, cfg)? {
367                return Ok(Some(result));
368            }
369        }
370        // We want to execute the default estimators last so that the user-provided ones have
371        // a chance of providing an estimation.
372        // If none of the user-provided returned an estimation, the default ones are used.
373        for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
374            if let Some(result) = default_estimator.scale_up_leaf_node(plan, task_count, cfg)? {
375                return Ok(Some(result));
376            }
377        }
378        Ok(None)
379    }
380
381    fn route_tasks(&self, routing_ctx: &TaskRoutingContext<'_>) -> Result<Option<Vec<Url>>> {
382        for estimator in &self.user_provided {
383            if let Some(result) = estimator.route_tasks(routing_ctx)? {
384                return Ok(Some(result));
385            }
386        }
387        Ok(None)
388    }
389}
390
/// Unit tests for [CombinedTaskEstimator] precedence rules and for `rebalance_round_robin`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::networking::WorkerResolverExtension;
    use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver;
    use crate::test_utils::parquet::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::prelude::SessionContext;

    // When several user-provided estimators return an estimation, the first registered wins.
    #[tokio::test]
    async fn test_first_user_estimator_wins() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(10);
        combined.push(20);

        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 10);
        Ok(())
    }

    // Estimators returning `None` are skipped until one returns an estimation.
    #[tokio::test]
    async fn test_continues_until_some() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(|_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions| None);
        combined.push(30);

        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 30);
        Ok(())
    }

    #[tokio::test]
    async fn test_defaults_to_file_scan_config_task_estimator() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(|_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions| None);

        // No user estimator returns a value, so the default FileScanConfigTaskEstimator kicks in.
        // Size the per-partition budget (with target_partitions pinned to 1) so the scan splits
        // into exactly 3 partitions.
        let node = make_data_source_exec().await?;
        let bytes_per_partition = total_scan_bytes(&node).div_ceil(3);
        let task_count = combined.task_count(node, |mut cfg| {
            cfg.file_scan_config_bytes_per_partition = bytes_per_partition;
            cfg
        });
        assert_eq!(task_count, 3);
        Ok(())
    }

    /// Sums the effective size of every file scanned by `node`.
    ///
    /// Panics if `node` is not a [DataSourceExec] over a [FileScanConfig].
    fn total_scan_bytes(node: &Arc<dyn ExecutionPlan>) -> usize {
        let dse = node.downcast_ref::<DataSourceExec>().unwrap();
        let file_scan = dse.data_source().downcast_ref::<FileScanConfig>().unwrap();
        file_scan
            .file_groups
            .iter()
            .flat_map(|file_group| file_group.files())
            .map(|file| file.effective_size() as usize)
            .sum()
    }

    // 8 items dealt over 5 groups: the first 3 groups get the 3 extra items.
    #[test]
    fn test_rebalance_round_robin_fixes_group_boundary_skew() {
        let items = (0..8).collect::<Vec<_>>();
        let groups = rebalance_round_robin(items, 5);
        let sizes = groups.iter().map(Vec::len).collect::<Vec<_>>();
        assert_eq!(sizes, vec![2, 2, 2, 1, 1]);
    }

    #[test]
    fn test_rebalance_round_robin_pads_with_empty_groups() {
        // With fewer items than target groups, the extra groups are kept empty rather than
        // dropped. This guarantees every task ends up with the same number of partitions.
        let items = vec![10, 20, 30];
        let groups = rebalance_round_robin(items, 5);
        let sizes = groups.iter().map(Vec::len).collect::<Vec<_>>();
        assert_eq!(sizes, vec![1, 1, 1, 0, 0]);
    }

    // Test-only helpers on CombinedTaskEstimator.
    impl CombinedTaskEstimator {
        /// Appends `value` to the list of user-provided estimators.
        fn push(&mut self, value: impl TaskEstimator + Send + Sync + 'static) {
            self.user_provided.push(Arc::new(value));
        }

        /// Runs `task_estimation` on `node` with a fresh config and returns the estimated
        /// task count. `f` can tweak the [DistributedConfig] before it is registered.
        ///
        /// Panics if no estimation is returned.
        fn task_count(
            &self,
            node: Arc<dyn ExecutionPlan>,
            f: impl FnOnce(DistributedConfig) -> DistributedConfig,
        ) -> usize {
            let mut cfg = ConfigOptions::default();
            // Pin target_partitions so the byte-based estimation is deterministic regardless of
            // the host's core count.
            cfg.execution.target_partitions = 1;
            let d_cfg = DistributedConfig {
                file_scan_config_bytes_per_partition: 1,
                __private_worker_resolver: WorkerResolverExtension(Arc::new(
                    InMemoryWorkerResolver::new(3),
                )),
                ..Default::default()
            };
            cfg.extensions.insert(f(d_cfg));
            self.task_estimation(&node, &cfg)
                .unwrap()
                .task_count
                .as_usize()
        }
    }

    /// Plans `SELECT * FROM weather` over the registered parquet test tables and returns the
    /// leaf of the physical plan, reached by following the first child of every node.
    async fn make_data_source_exec() -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let ctx = SessionContext::new();
        register_parquet_tables(&ctx).await?;
        let mut plan = ctx
            .sql("SELECT * FROM weather")
            .await?
            .create_physical_plan()
            .await?;
        while !plan.children().is_empty() {
            plan = Arc::clone(plan.children()[0])
        }
        Ok(plan)
    }

    // Test-only convenience: lets a closure act as a TaskEstimator whose `task_estimation` is
    // the closure itself and whose `scale_up_leaf_node` never transforms anything.
    impl<F: Fn(&Arc<dyn ExecutionPlan>, &ConfigOptions) -> Option<TaskEstimation>> TaskEstimator for F {
        fn task_estimation(
            &self,
            plan: &Arc<dyn ExecutionPlan>,
            cfg: &ConfigOptions,
        ) -> Option<TaskEstimation> {
            self(plan, cfg)
        }

        fn scale_up_leaf_node(
            &self,
            _plan: &Arc<dyn ExecutionPlan>,
            _task_count: usize,
            _cfg: &ConfigOptions,
        ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
            Ok(None)
        }
    }
}