use crate::config_extension_ext::set_distributed_option_extension;
use crate::{DistributedConfig, PartitionIsolatorExec};
use datafusion::catalog::memory::DataSourceExec;
use datafusion::config::ConfigOptions;
use datafusion::datasource::physical_plan::FileScanConfig;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use delegate::delegate;
use std::fmt::Debug;
use std::sync::Arc;
/// A task count proposed by a [`TaskEstimator`], tagged with the kind of proposal it is.
///
/// NOTE(review): the difference between `Desired` and `Maximum` is interpreted by the caller of
/// the estimators, not in this type; `Maximum` presumably acts as an upper bound — confirm there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskCountAnnotation {
    /// The number of tasks the estimator would like to use.
    Desired(usize),
    /// A task count annotated as a maximum.
    Maximum(usize),
}

impl From<TaskCountAnnotation> for usize {
    /// Extracts the raw task count, discarding the annotation kind.
    fn from(annotation: TaskCountAnnotation) -> Self {
        annotation.as_usize()
    }
}

impl TaskCountAnnotation {
    /// Returns the wrapped task count regardless of the annotation kind.
    pub fn as_usize(&self) -> usize {
        match self {
            Self::Desired(desired) => *desired,
            Self::Maximum(maximum) => *maximum,
        }
    }

    /// Caps the wrapped task count at `limit`, keeping the annotation kind unchanged.
    pub(crate) fn limit(self, limit: usize) -> Self {
        match self {
            Self::Desired(desired) => Self::Desired(desired.min(limit)),
            Self::Maximum(maximum) => Self::Maximum(maximum.min(limit)),
        }
    }
}
/// The answer a [`TaskEstimator`] gives for a single plan node.
#[derive(Debug, Clone)]
pub struct TaskEstimation {
    /// The proposed task count, annotated as either `Desired` or `Maximum`.
    pub task_count: TaskCountAnnotation,
}

impl TaskEstimation {
    /// Builds an estimation whose task count is `TaskCountAnnotation::Maximum(value)`.
    pub fn maximum(value: usize) -> Self {
        Self {
            task_count: TaskCountAnnotation::Maximum(value),
        }
    }

    /// Builds an estimation whose task count is `TaskCountAnnotation::Desired(value)`.
    pub fn desired(value: usize) -> Self {
        Self {
            task_count: TaskCountAnnotation::Desired(value),
        }
    }
}
/// Proposes how many tasks a physical-plan node should be distributed across, and knows how
/// to rewrite leaf nodes so their work can be split across that many tasks.
///
/// Both methods return `None` to mean "this estimator has no opinion about this node"; a
/// combining estimator then falls through to the next estimator in its list.
pub trait TaskEstimator {
    /// Returns the task count this estimator proposes for `plan`, or `None` if it does not
    /// know how to estimate this node.
    fn task_estimation(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        cfg: &ConfigOptions,
    ) -> Option<TaskEstimation>;

    /// Rewrites the leaf node `plan` so it can be executed by `task_count` tasks, or returns
    /// `None` if this estimator does not know how to split this node.
    ///
    /// The default implementation always returns `None`, which suits estimators that only
    /// propose task counts and never rewrite plans.
    fn scale_up_leaf_node(
        &self,
        _plan: &Arc<dyn ExecutionPlan>,
        _task_count: usize,
        _cfg: &ConfigOptions,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        None
    }
}
/// A bare `usize` works as a fixed estimator: every leaf node (a node without children) is
/// given `Desired(self)` tasks, while inner nodes are left to other estimators.
impl TaskEstimator for usize {
    fn task_estimation(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        _: &ConfigOptions,
    ) -> Option<TaskEstimation> {
        // Only leaves get a fixed estimate; anything with children is not our concern.
        if !plan.children().is_empty() {
            return None;
        }
        let task_count = TaskCountAnnotation::Desired(*self);
        Some(TaskEstimation { task_count })
    }

    fn scale_up_leaf_node(
        &self,
        _: &Arc<dyn ExecutionPlan>,
        _: usize,
        _: &ConfigOptions,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // A constant has no knowledge of how to split a plan.
        None
    }
}
/// Lets a shared `Arc<dyn TaskEstimator>` be used wherever a `TaskEstimator` is expected by
/// forwarding both methods to the estimator behind the `Arc`.
impl TaskEstimator for Arc<dyn TaskEstimator> {
    // `delegate!` generates `self.as_ref().<method>(..)` forwarding bodies for each signature.
    delegate! {
        to self.as_ref() {
            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
        }
    }
}
/// Same forwarding as for `Arc<dyn TaskEstimator>`, but for the thread-safe trait object.
///
/// This is the element type stored in `CombinedTaskEstimator::user_provided`, so those
/// entries can be called directly as estimators.
impl TaskEstimator for Arc<dyn TaskEstimator + Send + Sync> {
    // `delegate!` generates `self.as_ref().<method>(..)` forwarding bodies for each signature.
    delegate! {
        to self.as_ref() {
            fn task_estimation(&self, plan: &Arc<dyn ExecutionPlan>, cfg: &ConfigOptions) -> Option<TaskEstimation>;
            fn scale_up_leaf_node(&self, plan: &Arc<dyn ExecutionPlan>, task_count: usize, cfg: &ConfigOptions) -> Option<Arc<dyn ExecutionPlan>>;
        }
    }
}
/// Registers `estimator` as a user-provided [`TaskEstimator`] on `cfg`.
///
/// If `cfg` already carries a [`DistributedConfig`] extension, the estimator is appended to
/// that config's user-provided list, so estimators registered earlier are still consulted
/// first. Otherwise a default [`DistributedConfig`] holding only this estimator is installed.
pub(crate) fn set_distributed_task_estimator(
    cfg: &mut SessionConfig,
    estimator: impl TaskEstimator + Send + Sync + 'static,
) {
    let estimator: Arc<dyn TaskEstimator + Send + Sync> = Arc::new(estimator);

    let existing = cfg.options_mut().extensions.get_mut::<DistributedConfig>();
    if let Some(distributed_cfg) = existing {
        distributed_cfg
            .__private_task_estimator
            .user_provided
            .push(estimator);
        return;
    }

    // No distributed config yet: build one around a fresh combined estimator.
    let mut task_estimator = CombinedTaskEstimator::default();
    task_estimator.user_provided.push(estimator);
    set_distributed_option_extension(
        cfg,
        DistributedConfig {
            __private_task_estimator: task_estimator,
            ..Default::default()
        },
    );
}
/// Built-in estimator for leaf nodes that scan files: a `DataSourceExec` whose data source is
/// a `FileScanConfig`.
#[derive(Debug)]
struct FileScanConfigTaskEstimator;

impl TaskEstimator for FileScanConfigTaskEstimator {
    /// Proposes `ceil(files / files_per_task)` tasks (at least one) for a file scan.
    ///
    /// Returns `None` when `plan` is not a file-scan `DataSourceExec` or when `cfg` carries no
    /// `DistributedConfig` extension.
    fn task_estimation(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        cfg: &ConfigOptions,
    ) -> Option<TaskEstimation> {
        let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
        let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;
        let d_cfg = cfg.extensions.get::<DistributedConfig>()?;

        let partitioned_files: usize = file_scan
            .file_groups
            .iter()
            .map(|file_group| file_group.len())
            .sum();
        // `div_ceil` panics on a zero divisor, so a `files_per_task` of 0 is treated as 1.
        let files_per_task = d_cfg.files_per_task.max(1);
        // A scan over zero files still has to run somewhere, so never propose zero tasks.
        let task_count = partitioned_files.div_ceil(files_per_task).max(1);

        Some(TaskEstimation {
            task_count: TaskCountAnnotation::Desired(task_count),
        })
    }

    /// Splits every file group of a file scan into up to `task_count` groups and wraps the
    /// rewritten scan in a `PartitionIsolatorExec` for `task_count` tasks.
    ///
    /// With `task_count <= 1` nothing needs splitting and `plan` is returned unchanged (for any
    /// plan type). Otherwise returns `None` when `plan` is not a file-scan `DataSourceExec`.
    fn scale_up_leaf_node(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        task_count: usize,
        _cfg: &ConfigOptions,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // Zero is covered too: `split_files(0)` and a zero-task isolator are both invalid.
        if task_count <= 1 {
            return Some(Arc::clone(plan));
        }
        let dse: &DataSourceExec = plan.as_any().downcast_ref()?;
        let file_scan: &FileScanConfig = dse.data_source().as_any().downcast_ref()?;

        let mut new_file_scan = file_scan.clone();
        // Move the groups out of the clone we already own instead of cloning them again.
        let file_groups = std::mem::take(&mut new_file_scan.file_groups);
        new_file_scan.file_groups = file_groups
            .into_iter()
            .flat_map(|file_group| file_group.split_files(task_count))
            .collect();

        let plan = DataSourceExec::from_data_source(new_file_scan);
        Some(Arc::new(PartitionIsolatorExec::new(plan, task_count)))
    }
}
/// Chains user-provided [`TaskEstimator`]s with the built-in default estimators.
///
/// Each query is answered by the first estimator that returns `Some`. User-provided estimators
/// are consulted in the order they were pushed. The built-in defaults (currently the file-scan
/// estimator) are tried only after all of them have declined.
#[derive(Clone, Default)]
pub(crate) struct CombinedTaskEstimator {
    /// Estimators supplied by the user, consulted front to back.
    pub(crate) user_provided: Vec<Arc<dyn TaskEstimator + Send + Sync>>,
}

/// Built-in estimators, consulted in order once every user-provided estimator has declined.
const DEFAULT_TASK_ESTIMATORS: &[&dyn TaskEstimator] = &[&FileScanConfigTaskEstimator];

impl TaskEstimator for CombinedTaskEstimator {
    fn task_estimation(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        cfg: &ConfigOptions,
    ) -> Option<TaskEstimation> {
        let from_user = self
            .user_provided
            .iter()
            .find_map(|estimator| estimator.task_estimation(plan, cfg));
        // `or_else` keeps the defaults lazy: they run only if no user estimator answered.
        from_user.or_else(|| {
            DEFAULT_TASK_ESTIMATORS
                .iter()
                .copied()
                .find_map(|estimator| estimator.task_estimation(plan, cfg))
        })
    }

    fn scale_up_leaf_node(
        &self,
        plan: &Arc<dyn ExecutionPlan>,
        task_count: usize,
        cfg: &ConfigOptions,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        let from_user = self
            .user_provided
            .iter()
            .find_map(|estimator| estimator.scale_up_leaf_node(plan, task_count, cfg));
        from_user.or_else(|| {
            DEFAULT_TASK_ESTIMATORS
                .iter()
                .copied()
                .find_map(|estimator| estimator.scale_up_leaf_node(plan, task_count, cfg))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::networking::WorkerResolverExtension;
    use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver;
    use crate::test_utils::parquet::register_parquet_tables;
    use datafusion::error::DataFusionError;
    use datafusion::prelude::SessionContext;

    /// When several user estimators answer, the one pushed first takes precedence.
    #[tokio::test]
    async fn test_first_user_estimator_wins() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(10);
        combined.push(20);
        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 10);
        Ok(())
    }

    /// A user estimator returning `None` is skipped and the next one is consulted.
    #[tokio::test]
    async fn test_continues_until_some() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(|_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions| None);
        combined.push(30);
        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 30);
        Ok(())
    }

    /// When every user estimator declines, the built-in file-scan estimator answers.
    // NOTE(review): with `files_per_task: 1`, the expected 3 presumably equals the number of
    // files behind the `weather` test table — confirm against the parquet fixtures.
    #[tokio::test]
    async fn test_defaults_to_file_scan_config_task_estimator() -> Result<(), DataFusionError> {
        let mut combined = CombinedTaskEstimator::default();
        combined.push(|_: &Arc<dyn ExecutionPlan>, _: &ConfigOptions| None);
        let node = make_data_source_exec().await?;
        assert_eq!(combined.task_count(node, |cfg| cfg), 3);
        Ok(())
    }

    // Test-only helpers on the estimator under test.
    impl CombinedTaskEstimator {
        /// Appends `value` to the user-provided estimators (after any already pushed).
        fn push(&mut self, value: impl TaskEstimator + Send + Sync + 'static) {
            self.user_provided.push(Arc::new(value));
        }

        /// Runs `task_estimation` on `node` and returns the raw task count.
        ///
        /// The config holds a `DistributedConfig` with `files_per_task: 1` and an
        /// `InMemoryWorkerResolver::new(3)` resolver, after `f` has had a chance to adjust it.
        /// Panics if no estimator answers.
        fn task_count(
            &self,
            node: Arc<dyn ExecutionPlan>,
            f: impl FnOnce(DistributedConfig) -> DistributedConfig,
        ) -> usize {
            let mut cfg = ConfigOptions::default();
            let d_cfg = DistributedConfig {
                files_per_task: 1,
                __private_worker_resolver: WorkerResolverExtension(Arc::new(
                    InMemoryWorkerResolver::new(3),
                )),
                ..Default::default()
            };
            cfg.extensions.insert(f(d_cfg));
            self.task_estimation(&node, &cfg)
                .unwrap()
                .task_count
                .as_usize()
        }
    }

    /// Plans `SELECT * FROM weather` and returns its leaf node, reached by repeatedly following
    /// the first child of each node.
    async fn make_data_source_exec() -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        let ctx = SessionContext::new();
        register_parquet_tables(&ctx).await?;
        let mut plan = ctx
            .sql("SELECT * FROM weather")
            .await?
            .create_physical_plan()
            .await?;
        // Descend along the first child until reaching a node with no children.
        while !plan.children().is_empty() {
            plan = Arc::clone(plan.children()[0])
        }
        Ok(plan)
    }

    /// Lets plain closures act as estimators in tests; such estimators never scale up leaves.
    impl<F: Fn(&Arc<dyn ExecutionPlan>, &ConfigOptions) -> Option<TaskEstimation>> TaskEstimator for F {
        fn task_estimation(
            &self,
            plan: &Arc<dyn ExecutionPlan>,
            cfg: &ConfigOptions,
        ) -> Option<TaskEstimation> {
            self(plan, cfg)
        }

        fn scale_up_leaf_node(
            &self,
            _plan: &Arc<dyn ExecutionPlan>,
            _task_count: usize,
            _cfg: &ConfigOptions,
        ) -> Option<Arc<dyn ExecutionPlan>> {
            None
        }
    }
}