datafusion_distributed/distributed_planner/
task_estimator.rs1use crate::config_extension_ext::set_distributed_option_extension;
2use crate::{DistributedConfig, PartitionIsolatorExec};
3use datafusion::catalog::memory::DataSourceExec;
4use datafusion::config::ConfigOptions;
5use datafusion::datasource::physical_plan::FileScanConfig;
6use datafusion::physical_plan::ExecutionPlan;
7use datafusion::prelude::SessionConfig;
8use delegate::delegate;
9use std::fmt::Debug;
10use std::sync::Arc;
11
/// Task-count hint that a task estimator attaches to a plan node.
///
/// Both variants wrap a plain `usize`; use [`TaskCountAnnotation::as_usize`] to read the number
/// regardless of the variant.
///
/// `PartialEq`/`Eq` are derived so annotations can be compared directly (e.g. in `assert_eq!`).
// NOTE(review): how the planner treats `Desired` vs `Maximum` is decided outside this file —
// the variant docs below only restate what the names say; confirm against the planner.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskCountAnnotation {
    /// The number of tasks the estimator would like to be used.
    Desired(usize),
    /// The largest number of tasks the estimator allows.
    Maximum(usize),
}
24
25impl From<TaskCountAnnotation> for usize {
26 fn from(annotation: TaskCountAnnotation) -> Self {
27 annotation.as_usize()
28 }
29}
30
31impl TaskCountAnnotation {
32 pub fn as_usize(&self) -> usize {
33 match self {
34 Self::Desired(desired) => *desired,
35 Self::Maximum(maximum) => *maximum,
36 }
37 }
38
39 pub(crate) fn limit(self, limit: usize) -> Self {
40 match self {
41 Self::Desired(desired) => Self::Desired(desired.min(limit)),
42 Self::Maximum(maximum) => Self::Maximum(maximum.min(limit)),
43 }
44 }
45}
46
47pub struct TaskEstimation {
50 pub task_count: TaskCountAnnotation,
59}
60
61impl TaskEstimation {
62 pub fn maximum(value: usize) -> Self {
71 TaskEstimation {
72 task_count: TaskCountAnnotation::Maximum(value),
73 }
74 }
75
76 pub fn desired(value: usize) -> Self {
83 TaskEstimation {
84 task_count: TaskCountAnnotation::Desired(value),
85 }
86 }
87}
88
89pub trait TaskEstimator {
97 fn task_estimation(
110 &self,
111 plan: &Arc<dyn ExecutionPlan>,
112 cfg: &ConfigOptions,
113 ) -> Option<TaskEstimation>;
114
115 fn scale_up_leaf_node(
119 &self,
120 plan: &Arc<dyn ExecutionPlan>,
121 task_count: usize,
122 cfg: &ConfigOptions,
123 ) -> Option<Arc<dyn ExecutionPlan>>;
124}
125
126impl TaskEstimator for usize {
127 fn task_estimation(
128 &self,
129 inputs: &Arc<dyn ExecutionPlan>,
130 _: &ConfigOptions,
131 ) -> Option<TaskEstimation> {
132 if inputs.children().is_empty() {
133 Some(TaskEstimation {
134 task_count: TaskCountAnnotation::Desired(*self),
135 })
136 } else {
137 None
138 }
139 }
140
141 fn scale_up_leaf_node(
142 &self,
143 _: &Arc<dyn ExecutionPlan>,
144 _: usize,
145 _: &ConfigOptions,
146 ) -> Option<Arc<dyn ExecutionPlan>> {
147 None
148 }
149}
150
151impl TaskEstimator for Arc<dyn TaskEstimator> {
152 delegate! {
153 to self.as_ref() {
154 fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
155 fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
156 }
157 }
158}
159
160impl TaskEstimator for Arc<dyn TaskEstimator + Send + Sync> {
161 delegate! {
162 to self.as_ref() {
163 fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
164 fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
165 }
166 }
167}
168
169pub(crate) fn set_distributed_task_estimator(
170 cfg: &mut SessionConfig,
171 estimator: impl TaskEstimator + Send + Sync + 'static,
172) {
173 let opts = cfg.options_mut();
174 if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
175 distributed_cfg
176 .__private_task_estimator
177 .user_provided
178 .push(Arc::new(estimator));
179 } else {
180 let mut estimators = CombinedTaskEstimator::default();
181 estimators.user_provided.push(Arc::new(estimator));
182 set_distributed_option_extension(
183 cfg,
184 DistributedConfig {
185 __private_task_estimator: estimators,
186 ..Default::default()
187 },
188 )
189 }
190}
191
/// Built-in estimator for `DataSourceExec` nodes whose data source is a `FileScanConfig`.
///
/// It is stateless; all inputs come from the plan node and the session configuration passed
/// to the [`TaskEstimator`] methods.
#[derive(Debug)]
struct FileScanConfigTaskEstimator;
198
199impl TaskEstimator for FileScanConfigTaskEstimator {
200 fn task_estimation(
201 &self,
202 plan: &Arc<dyn ExecutionPlan>,
203 cfg: &ConfigOptions,
204 ) -> Option<TaskEstimation> {
205 let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
206 let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;
207
208 let d_cfg = cfg.extensions.get::<DistributedConfig>()?;
209
210 let mut partitioned_files = 0;
212 for file_group in &file_scan.file_groups {
213 partitioned_files += file_group.len();
214 }
215
216 let task_count = partitioned_files.div_ceil(d_cfg.files_per_task);
219
220 Some(TaskEstimation {
221 task_count: TaskCountAnnotation::Desired(task_count),
222 })
223 }
224
225 fn scale_up_leaf_node(
226 &self,
227 plan: &Arc<dyn ExecutionPlan>,
228 task_count: usize,
229 _cfg: &ConfigOptions,
230 ) -> Option<Arc<dyn ExecutionPlan>> {
231 if task_count == 1 {
232 return Some(Arc::clone(plan));
233 }
234 let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
238 let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;
239
240 let mut new_file_scan = file_scan.clone();
241 new_file_scan.file_groups.clear();
242 for file_group in file_scan.file_groups.clone() {
243 new_file_scan
244 .file_groups
245 .extend(file_group.split_files(task_count));
246 }
247 let plan = DataSourceExec::from_data_source(new_file_scan);
248 Some(Arc::new(PartitionIsolatorExec::new(plan, task_count)))
249 }
250}
251
252#[derive(Clone, Default)]
256pub(crate) struct CombinedTaskEstimator {
257 pub(crate) user_provided: Vec<Arc<dyn TaskEstimator + Send + Sync>>,
258}
259
260impl TaskEstimator for CombinedTaskEstimator {
261 fn task_estimation(
262 &self,
263 plan: &Arc<dyn ExecutionPlan>,
264 cfg: &ConfigOptions,
265 ) -> Option<TaskEstimation> {
266 for estimator in &self.user_provided {
267 if let Some(result) = estimator.task_estimation(plan, cfg) {
268 return Some(result);
269 }
270 }
271 for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
275 if let Some(result) = default_estimator.task_estimation(plan, cfg) {
276 return Some(result);
277 }
278 }
279 None
280 }
281
282 fn scale_up_leaf_node(
283 &self,
284 plan: &Arc<dyn ExecutionPlan>,
285 task_count: usize,
286 cfg: &ConfigOptions,
287 ) -> Option<Arc<dyn ExecutionPlan>> {
288 for estimator in &self.user_provided {
289 if let Some(result) = estimator.scale_up_leaf_node(plan, task_count, cfg) {
290 return Some(result);
291 }
292 }
293 for default_estimator in [&FileScanConfigTaskEstimator as &dyn TaskEstimator] {
297 if let Some(result) = default_estimator.scale_up_leaf_node(plan, task_count, cfg) {
298 return Some(result);
299 }
300 }
301 None
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::networking::WorkerResolverExtension;
    use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver;
    use crate::test_utils::parquet::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::prelude::SessionContext;

    /// Test-only adapter: any function or closure with the `task_estimation` signature can be
    /// used as a [`TaskEstimator`]. Such estimators never scale up leaf nodes.
    impl<F> TaskEstimator for F
    where
        F: Fn(&Arc<dyn ExecutionPlan>, &ConfigOptions) -> Option<TaskEstimation>,
    {
        fn task_estimation(
            &self,
            plan: &Arc<dyn ExecutionPlan>,
            cfg: &ConfigOptions,
        ) -> Option<TaskEstimation> {
            self(plan, cfg)
        }

        fn scale_up_leaf_node(
            &self,
            _plan: &Arc<dyn ExecutionPlan>,
            _task_count: usize,
            _cfg: &ConfigOptions,
        ) -> Option<Arc<dyn ExecutionPlan>> {
            None
        }
    }

    /// Estimator that declines every node, used to exercise the fall-through logic.
    fn never_estimates(_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions) -> Option<TaskEstimation> {
        None
    }

    impl CombinedTaskEstimator {
        /// Appends a user-provided estimator.
        fn push(&mut self, value: impl TaskEstimator + Send + Sync + 'static) {
            self.user_provided.push(Arc::new(value));
        }

        /// Runs `task_estimation` on `node` and returns the resulting task count.
        ///
        /// The config starts from `files_per_task = 1` and an in-memory worker resolver built
        /// with `InMemoryWorkerResolver::new(3)`; `f` may adjust it before it is installed.
        fn task_count(
            &self,
            node: Arc<dyn ExecutionPlan>,
            f: impl FnOnce(DistributedConfig) -> DistributedConfig,
        ) -> usize {
            let worker_resolver = WorkerResolverExtension(Arc::new(InMemoryWorkerResolver::new(3)));
            let base = DistributedConfig {
                files_per_task: 1,
                __private_worker_resolver: worker_resolver,
                ..Default::default()
            };

            let mut cfg = ConfigOptions::default();
            cfg.extensions.insert(f(base));

            let estimation = self.task_estimation(&node, &cfg).unwrap();
            usize::from(estimation.task_count)
        }
    }

    /// Plans `SELECT * FROM weather` over the registered parquet tables and walks down the
    /// first child of every node until a leaf is reached.
    async fn make_data_source_exec() -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let ctx = SessionContext::new();
        register_parquet_tables(&ctx).await?;

        let df = ctx.sql("SELECT * FROM weather").await?;
        let mut plan = df.create_physical_plan().await?;
        while !plan.children().is_empty() {
            plan = Arc::clone(plan.children()[0]);
        }
        Ok(plan)
    }

    #[tokio::test]
    async fn test_first_user_estimator_wins() -> Result<(), DataFusionError> {
        let mut estimators = CombinedTaskEstimator::default();
        estimators.push(10);
        estimators.push(20);

        let leaf = make_data_source_exec().await?;
        assert_eq!(estimators.task_count(leaf, |cfg| cfg), 10);
        Ok(())
    }

    #[tokio::test]
    async fn test_continues_until_some() -> Result<(), DataFusionError> {
        let mut estimators = CombinedTaskEstimator::default();
        estimators.push(never_estimates);
        estimators.push(30);

        let leaf = make_data_source_exec().await?;
        assert_eq!(estimators.task_count(leaf, |cfg| cfg), 30);
        Ok(())
    }

    #[tokio::test]
    async fn test_defaults_to_file_scan_config_task_estimator() -> Result<(), DataFusionError> {
        let mut estimators = CombinedTaskEstimator::default();
        estimators.push(never_estimates);

        let leaf = make_data_source_exec().await?;
        assert_eq!(estimators.task_count(leaf, |cfg| cfg), 3);
        Ok(())
    }
}