use crate::config_extension_ext::set_distributed_option_extension;
use crate::{DistributedConfig, WorkUnit, WorkUnitFeed, WorkUnitFeedProvider};
use datafusion::common::Result;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use futures::StreamExt;
use futures::stream::BoxStream;
use std::sync::Arc;
use uuid::Uuid;
/// Object-safe, type-erased view of a work unit feed.
///
/// Allows feeds built from different `WorkUnitFeedProvider` types to be handled
/// through a single `&dyn ErasedWorkUnitFeed`. The blanket impl for
/// `WorkUnitFeed<T>` in this module boxes each concrete work unit into a
/// `Box<dyn WorkUnit>`.
pub(crate) trait ErasedWorkUnitFeed: Send + Sync {
    /// Returns the identifier of the underlying feed.
    fn id(&self) -> Uuid;
    /// Produces the stream of work units for `partition`.
    ///
    /// The outer `Result` reports failure to create the stream. Each stream item
    /// is itself a `Result` carrying either a boxed work unit or an error.
    fn feed(
        &self,
        partition: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<BoxStream<'static, Result<Box<dyn WorkUnit>>>>;
}
/// Bridges the typed `WorkUnitFeed<T>` to the object-safe `ErasedWorkUnitFeed`.
///
/// The provider's concrete work-unit type is boxed into `Box<dyn WorkUnit>` as
/// each item flows through the stream. Errors are forwarded unchanged.
impl<T> ErasedWorkUnitFeed for WorkUnitFeed<T>
where
    T: WorkUnitFeedProvider + 'static,
    T::WorkUnit: 'static,
{
    fn id(&self) -> Uuid {
        // The identifier lives directly on the typed feed.
        self.id
    }

    fn feed(
        &self,
        partition: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<BoxStream<'static, Result<Box<dyn WorkUnit>>>> {
        // A failure to create the typed stream is returned immediately.
        let typed = WorkUnitFeed::feed(self, partition, ctx)?;
        // Per-item errors pass straight through; successful items are boxed
        // (the closure's return type drives the unsized coercion).
        let erased = typed.map(|item| -> Result<Box<dyn WorkUnit>> {
            let unit = item?;
            Ok(Box::new(unit))
        });
        Ok(erased.boxed())
    }
}
/// Type-erased lookup of a work unit feed on an execution plan node.
///
/// Implemented in this module for closures of the form
/// `Fn(&Arc<dyn ExecutionPlan>) -> Option<&WorkUnitFeed<T>>`. This lets one
/// registry hold getters for many different provider types side by side.
trait WorkUnitFeedGetter: Send + Sync {
    /// Returns the feed exposed by `node`, or `None` if this getter yields none.
    ///
    /// The returned reference borrows from `node` (lifetime `'a`), not from the
    /// getter itself.
    fn get_work_unit_feed<'a>(
        &self,
        node: &'a Arc<dyn ExecutionPlan>,
    ) -> Option<&'a dyn ErasedWorkUnitFeed>;
}
/// Any thread-safe closure that maps a plan node to an optional typed
/// `WorkUnitFeed<T>` can act as a `WorkUnitFeedGetter`. The typed feed it
/// returns is exposed through the erased `ErasedWorkUnitFeed` interface.
impl<T, F> WorkUnitFeedGetter for F
where
    T: WorkUnitFeedProvider + 'static,
    T::WorkUnit: 'static,
    F: for<'a> Fn(&'a Arc<dyn ExecutionPlan>) -> Option<&'a WorkUnitFeed<T>>
        + Send
        + Sync
        + 'static,
{
    fn get_work_unit_feed<'a>(
        &self,
        node: &'a Arc<dyn ExecutionPlan>,
    ) -> Option<&'a dyn ErasedWorkUnitFeed> {
        // Bail out with `None` when the closure does not recognise the node.
        let typed = (self)(node)?;
        // The annotated binding performs the unsized coercion to the trait object.
        let erased: &'a dyn ErasedWorkUnitFeed = typed;
        Some(erased)
    }
}
/// Ordered collection of type-erased work unit feed getters.
///
/// `set_distributed_work_unit_feed` appends getters to it.
/// `WorkUnitFeedRegistry::get_work_unit_feed` consults them in insertion order
/// and stops at the first one that yields a feed. Cloning copies the `Vec` of
/// `Arc` handles, so the getters themselves are shared rather than duplicated.
#[derive(Default, Clone)]
pub(crate) struct WorkUnitFeedRegistry {
    // Getters in registration order.
    entries: Vec<Arc<dyn WorkUnitFeedGetter>>,
}
impl WorkUnitFeedRegistry {
    /// Looks up the work unit feed attached to `node`.
    ///
    /// Getters are asked in registration order, and the first feed any of them
    /// returns wins. Yields `None` when no registered getter recognises `node`.
    /// The returned reference borrows from `node`, not from the registry.
    pub(crate) fn get_work_unit_feed<'a>(
        &self,
        node: &'a Arc<dyn ExecutionPlan>,
    ) -> Option<&'a dyn ErasedWorkUnitFeed> {
        self.entries
            .iter()
            .find_map(|getter| getter.get_work_unit_feed(node))
    }
}
pub(crate) fn set_distributed_work_unit_feed<T, F>(cfg: &mut SessionConfig, getter: F)
where
T: WorkUnitFeedProvider + 'static,
T::WorkUnit: 'static,
F: Fn(&Arc<dyn ExecutionPlan>) -> Option<&WorkUnitFeed<T>> + Send + Sync + 'static,
{
let opts = cfg.options_mut();
if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
distributed_cfg
.__private_work_unit_feed_registry
.entries
.push(Arc::new(getter));
} else {
let mut registry = WorkUnitFeedRegistry::default();
registry.entries.push(Arc::new(getter));
set_distributed_option_extension(
cfg,
DistributedConfig {
__private_work_unit_feed_registry: registry,
..Default::default()
},
)
}
}