// Source: datafusion_distributed/distributed_planner/task_estimator.rs

1use crate::config_extension_ext::set_distributed_option_extension;
2use crate::{DistributedConfig, PartitionIsolatorExec};
3use datafusion::catalog::memory::DataSourceExec;
4use datafusion::config::ConfigOptions;
5use datafusion::datasource::physical_plan::FileScanConfig;
6use datafusion::physical_plan::ExecutionPlan;
7use datafusion::prelude::SessionConfig;
8use delegate::delegate;
9use std::fmt::Debug;
10use std::sync::Arc;
11
/// Per-node hint attached to an [ExecutionPlan] describing how many distributed tasks it
/// should be executed in.
#[derive(Debug, Clone)]
pub enum TaskCountAnnotation {
    /// A soft preference for the number of distributed tasks. The final task count of the
    /// annotated node may end up different, because the desired task counts of adjacent
    /// nodes are also taken into account.
    Desired(usize),
    /// A hard ceiling on the number of distributed tasks. `Maximum(1)` is the usual way of
    /// stating that the node cannot be executed in a distributed fashion.
    Maximum(usize),
}

impl From<TaskCountAnnotation> for usize {
    /// Extracts the raw task count, discarding whether it was a desired or a maximum value.
    fn from(annotation: TaskCountAnnotation) -> Self {
        let (TaskCountAnnotation::Desired(count) | TaskCountAnnotation::Maximum(count)) =
            annotation;
        count
    }
}

impl TaskCountAnnotation {
    /// Returns the task count carried by this annotation, regardless of its variant.
    pub fn as_usize(&self) -> usize {
        let (Self::Desired(count) | Self::Maximum(count)) = self;
        *count
    }

    /// Caps the carried task count at `limit`, keeping the variant (desired vs. maximum)
    /// unchanged.
    pub(crate) fn limit(self, limit: usize) -> Self {
        let capped = self.as_usize().min(limit);
        match self {
            Self::Desired(_) => Self::Desired(capped),
            Self::Maximum(_) => Self::Maximum(capped),
        }
    }
}
46
47/// Result of running a [TaskEstimator] on a leaf node. It tells the distributed planner hints
48/// about how many tasks should be used in [Stage]s that contain leaf nodes.
49pub struct TaskEstimation {
50    /// The number of tasks that should be used in the [Stage] containing the leaf node.
51    ///
52    /// Even if implementations get to decide this number, there are situations where it can
53    /// get overridden:
54    /// - If a [Stage] contains multiple leaf nodes, the one that declares the biggest
55    ///   task_count wins.
56    /// - If there are less available workers than this number, the number of available workers
57    ///   is chosen.
58    pub task_count: TaskCountAnnotation,
59}
60
61impl TaskEstimation {
62    /// Tells the distributed planner that the evaluated stage can have **at maximum** the provided
63    /// number of tasks, setting a hard upper limit.
64    ///
65    /// Returning `TaskEstimation::maximum(1)` tells the distributed planner that the evaluated
66    /// stage cannot be distributed.
67    ///
68    /// Even if a `TaskEstimation::maximum(N)` is provided, any other node in the same stage
69    /// providing a value of `TaskEstimation::maximum(M)` where `M` < `N` will have preference.
70    pub fn maximum(value: usize) -> Self {
71        TaskEstimation {
72            task_count: TaskCountAnnotation::Maximum(value),
73        }
74    }
75
76    /// Tells the distributed planner that the evaluated can **optimally** have the provided
77    /// number of tasks, setting a soft task count hint that can be overridden by others.
78    ///
79    /// The provided `TaskEstimation::desired(N)` can be overridden by:
80    /// - Other nodes providing a `TaskEstimation::desired(M)` where `M` > `N`.
81    /// - Any other node providing a `TaskEstimation::maximum(M)` where `M` can be anything.
82    pub fn desired(value: usize) -> Self {
83        TaskEstimation {
84            task_count: TaskCountAnnotation::Desired(value),
85        }
86    }
87}
88
89/// Given a leaf node, provides an estimation about how many tasks should be used in the
90/// stage containing it, and if the leaf node should be replaced by some other.
91///
92/// The distributed planner will try many [TaskEstimator]s in order until one provides an
93/// estimation for a specific leaf node. Once that's done, upper stages will get their task
94/// count calculated based on whether lower stages are reducing the cardinality of the data
95/// or increasing it.
96pub trait TaskEstimator {
97    /// Function applied to each node that returns a [TaskEstimation] hinting how many
98    /// tasks should be used in the [Stage] containing that node.
99    ///
100    /// All the [TaskEstimator] registered in the session will be applied to the node
101    /// until one returns an estimation.
102    ///
103    ///
104    /// If no estimation is returned from any of the registered [TaskEstimator]s, then:
105    /// - If the node is a leaf node,`Maximum(1)` is assumed, hinting the distributed planner
106    ///   that the leaf node cannot be distributed across tasks.
107    /// - If the node is a normal node in the plan, then the maximum task count from its children
108    ///   is inherited.
109    fn task_estimation(
110        &self,
111        plan: &Arc<dyn ExecutionPlan>,
112        cfg: &ConfigOptions,
113    ) -> Option<TaskEstimation>;
114
115    /// After a final task_count is decided, taking into account all the leaf nodes in the [Stage],
116    /// this allows performing a transformation in the leaf nodes for accounting for the fact that
117    /// they are going to run in multiple tasks.
118    fn scale_up_leaf_node(
119        &self,
120        plan: &Arc<dyn ExecutionPlan>,
121        task_count: usize,
122        cfg: &ConfigOptions,
123    ) -> Option<Arc<dyn ExecutionPlan>>;
124}
125
126impl TaskEstimator for usize {
127    fn task_estimation(
128        &self,
129        inputs: &Arc<dyn ExecutionPlan>,
130        _: &ConfigOptions,
131    ) -> Option<TaskEstimation> {
132        if inputs.children().is_empty() {
133            Some(TaskEstimation {
134                task_count: TaskCountAnnotation::Desired(*self),
135            })
136        } else {
137            None
138        }
139    }
140
141    fn scale_up_leaf_node(
142        &self,
143        _: &Arc<dyn ExecutionPlan>,
144        _: usize,
145        _: &ConfigOptions,
146    ) -> Option<Arc<dyn ExecutionPlan>> {
147        None
148    }
149}
150
/// Lets a type-erased `Arc<dyn TaskEstimator>` be used wherever a [TaskEstimator] is
/// expected. Both trait methods are forwarded (via the `delegate!` macro) to the wrapped
/// estimator obtained through `self.as_ref()`.
impl TaskEstimator for Arc<dyn TaskEstimator> {
    delegate! {
        to self.as_ref() {
            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
        }
    }
}
159
/// Same forwarding as for `Arc<dyn TaskEstimator>`, but for the `Send + Sync` trait object.
/// This is the element type stored in `CombinedTaskEstimator::user_provided`. Both trait
/// methods are forwarded (via the `delegate!` macro) to `self.as_ref()`.
impl TaskEstimator for Arc<dyn TaskEstimator + Send + Sync> {
    delegate! {
        to self.as_ref() {
            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
        }
    }
}
168
169pub(crate) fn set_distributed_task_estimator(
170    cfg: &mut SessionConfig,
171    estimator: impl TaskEstimator + Send + Sync + 'static,
172) {
173    let opts = cfg.options_mut();
174    if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
175        distributed_cfg
176            .__private_task_estimator
177            .user_provided
178            .push(Arc::new(estimator));
179    } else {
180        let mut estimators = CombinedTaskEstimator::default();
181        estimators.user_provided.push(Arc::new(estimator));
182        set_distributed_option_extension(
183            cfg,
184            DistributedConfig {
185                __private_task_estimator: estimators,
186                ..Default::default()
187            },
188        )
189    }
190}
191
192/// [TaskEstimator] implementation that acts on [DataSourceExec] nodes that contain
193/// [FileScanConfig]s data sources (e.g., Parquet or CSV files). it will read the
194/// [DistributedConfig].`files_per_task` field and assigns as many tasks as needed so that
195/// no task handles more than the configured files.
196#[derive(Debug)]
197struct FileScanConfigTaskEstimator;
198
199impl TaskEstimator for FileScanConfigTaskEstimator {
200    fn task_estimation(
201        &self,
202        plan: &Arc<dyn ExecutionPlan>,
203        cfg: &ConfigOptions,
204    ) -> Option<TaskEstimation> {
205        let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
206        let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;
207
208        let d_cfg = cfg.extensions.get::<DistributedConfig>()?;
209
210        // Count how many partitioned files we have in the FileScanConfig.
211        let mut partitioned_files = 0;
212        for file_group in &file_scan.file_groups {
213            partitioned_files += file_group.len();
214        }
215
216        // Based on the user-provided files_per_task configuration, do the math to calculate
217        // how many tasks should be used, without surpassing the number of available workers.
218        let task_count = partitioned_files.div_ceil(d_cfg.files_per_task);
219
220        Some(TaskEstimation {
221            task_count: TaskCountAnnotation::Desired(task_count),
222        })
223    }
224
225    fn scale_up_leaf_node(
226        &self,
227        plan: &Arc<dyn ExecutionPlan>,
228        task_count: usize,
229        _cfg: &ConfigOptions,
230    ) -> Option<Arc<dyn ExecutionPlan>> {
231        if task_count == 1 {
232            return Some(Arc::clone(plan));
233        }
234        // Based on the task count, attempt to scale up the partitions in the DataSourceExec by
235        // repartitioning it. This will result in a DataSourceExec with potentially a lot of
236        // partitions, but as we are going to wrap it with PartitionIsolatorExec, that's fine.
237        let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
238        let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;
239
240        let mut new_file_scan = file_scan.clone();
241        new_file_scan.file_groups.clear();
242        for file_group in file_scan.file_groups.clone() {
243            new_file_scan
244                .file_groups
245                .extend(file_group.split_files(task_count));
246        }
247        let plan = DataSourceExec::from_data_source(new_file_scan);
248        Some(Arc::new(PartitionIsolatorExec::new(plan, task_count)))
249    }
250}
251
252/// Tries multiple user-provided [TaskEstimator]s until one returns an estimation. If none
253/// returns an estimation, a set of default [TaskEstimation] implementations is tried. Right
254/// now the only default [TaskEstimation] is [FileScanConfigTaskEstimator].
255#[derive(Clone, Default)]
256pub(crate) struct CombinedTaskEstimator {
257    pub(crate) user_provided: Vec<Arc<dyn TaskEstimator + Send + Sync>>,
258}
259
260impl TaskEstimator for CombinedTaskEstimator {
261    fn task_estimation(
262        &self,
263        plan: &Arc<dyn ExecutionPlan>,
264        cfg: &ConfigOptions,
265    ) -> Option<TaskEstimation> {
266        for estimator in &self.user_provided {
267            if let Some(result) = estimator.task_estimation(plan, cfg) {
268                return Some(result);
269            }
270        }
271        // We want to execute the default estimators last so that the user-provided ones have
272        // a chance of providing an estimation.
273        // If none of the user-provided returned an estimation, the default ones are used.
274        for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
275            if let Some(result) = default_estimator.task_estimation(plan, cfg) {
276                return Some(result);
277            }
278        }
279        None
280    }
281
282    fn scale_up_leaf_node(
283        &self,
284        plan: &Arc<dyn ExecutionPlan>,
285        task_count: usize,
286        cfg: &ConfigOptions,
287    ) -> Option<Arc<dyn ExecutionPlan>> {
288        for estimator in &self.user_provided {
289            if let Some(result) = estimator.scale_up_leaf_node(plan, task_count, cfg) {
290                return Some(result);
291            }
292        }
293        // We want to execute the default estimators last so that the user-provided ones have
294        // a chance of providing an estimation.
295        // If none of the user-provided returned an estimation, the default ones are used.
296        for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
297            if let Some(result) = default_estimator.scale_up_leaf_node(plan, task_count, cfg) {
298                return Some(result);
299            }
300        }
301        None
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::networking::WorkerResolverExtension;
    use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver;
    use crate::test_utils::parquet::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::prelude::SessionContext;

    /// The first user-provided estimator that returns `Some` decides the task count; the
    /// estimators registered after it are not consulted.
    #[tokio::test]
    async fn test_first_user_estimator_wins() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(10);
        combined.push(20);

        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 10);
        Ok(())
    }

    /// User-provided estimators that return `None` are skipped until one returns an
    /// estimation.
    #[tokio::test]
    async fn test_continues_until_some() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(|_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions| None);
        combined.push(30);

        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 30);
        Ok(())
    }

    /// When every user-provided estimator returns `None`, the built-in file-scan estimator
    /// is used. The `task_count` helper sets `files_per_task = 1`, so the expected 3 tasks
    /// presumably correspond to 3 files backing the `weather` table. Confirm against the
    /// test fixtures.
    #[tokio::test]
    async fn test_defaults_to_file_scan_config_task_estimator() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(|_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions| None);

        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 3);
        Ok(())
    }

    /// Test-only helpers on [CombinedTaskEstimator].
    impl CombinedTaskEstimator {
        /// Registers `value` as the next user-provided estimator.
        fn push(&mut self, value: impl TaskEstimator + Send + Sync + 'static) {
            self.user_provided.push(Arc::new(value));
        }

        /// Runs `task_estimation` on `node` against a fresh `ConfigOptions` and returns the
        /// resulting task count.
        ///
        /// The options hold a `DistributedConfig` with `files_per_task = 1` and an
        /// `InMemoryWorkerResolver::new(3)` worker resolver (presumably 3 workers). That
        /// config is passed through `f` before being inserted.
        ///
        /// Panics if no estimator returns an estimation.
        fn task_count(
            &self,
            node: Arc<dyn ExecutionPlan>,
            f: impl FnOnce(DistributedConfig) -> DistributedConfig,
        ) -> usize {
            let mut cfg = ConfigOptions::default();
            let d_cfg = DistributedConfig {
                files_per_task: 1,
                __private_worker_resolver: WorkerResolverExtension(Arc::new(
                    InMemoryWorkerResolver::new(3),
                )),
                ..Default::default()
            };
            cfg.extensions.insert(f(d_cfg));
            self.task_estimation(&node, &cfg)
                .unwrap()
                .task_count
                .as_usize()
        }
    }

    /// Plans `SELECT * FROM weather` over the registered parquet tables. It then descends
    /// through the first child of each node and returns the leaf it reaches.
    async fn make_data_source_exec() -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let ctx = SessionContext::new();
        register_parquet_tables(&ctx).await?;
        let mut plan = ctx
            .sql("SELECT * FROM weather")
            .await?
            .create_physical_plan()
            .await?;
        // Follow the first child until reaching a node without children.
        while !plan.children().is_empty() {
            plan = Arc::clone(plan.children()[0])
        }
        Ok(plan)
    }

    /// Lets plain closures act as estimators in tests. `task_estimation` calls the closure,
    /// and `scale_up_leaf_node` always returns `None`.
    impl<F: Fn(&Arc<dyn ExecutionPlan>, &ConfigOptions) -> Option<TaskEstimation>> TaskEstimator for F {
        fn task_estimation(
            &self,
            plan: &Arc<dyn ExecutionPlan>,
            cfg: &ConfigOptions,
        ) -> Option<TaskEstimation> {
            self(plan, cfg)
        }

        fn scale_up_leaf_node(
            &self,
            _plan: &Arc<dyn ExecutionPlan>,
            _task_count: usize,
            _cfg: &ConfigOptions,
        ) -> Option<Arc<dyn ExecutionPlan>> {
            None
        }
    }
}