datafusion_distributed/work_unit_feed/work_unit_feed_provider.rs
1use crate::WorkUnit;
2use datafusion::common::Result;
3use datafusion::execution::TaskContext;
4use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
5use futures::stream::BoxStream;
6use std::fmt::Debug;
7use std::sync::Arc;
8
9/// Extension point for building user-defined work unit streams consumed by a
10/// [`crate::WorkUnitFeed`] embedded in a leaf [`datafusion::physical_plan::ExecutionPlan`].
11///
12/// Implement this trait on a type that knows how to produce the per-partition stream of
13/// work items (e.g. file addresses, external queries, key ranges) that the leaf plan needs
14/// at runtime. Then wrap the implementation with [`crate::WorkUnitFeed::new`] and store
15/// the resulting [`crate::WorkUnitFeed`] as a field of your [`ExecutionPlan`] node.
16///
17/// In a distributed context the provider is only invoked on the **coordinating** stage
18/// that initiates the query. The work units it produces are serialized and streamed over
19/// the network to the workers, which expose the same typed stream to the leaf plan as if
20/// it were running locally.
21///
22/// See [`WorkUnitFeedProvider::feed`] for the per-call contract.
23pub trait WorkUnitFeedProvider: Send + Sync + Debug {
24 type WorkUnit: WorkUnit + Default;
25
26 /// Builds a [`WorkUnit`] stream for the given `partition`.
27 ///
28 /// This method is never invoked in a remote worker. On workers, the equivalent
29 /// leaf plan uses a remote provider that pulls the work units off the network —
30 /// user code doesn't need to implement that case.
31 ///
32 /// When implementing this method, [DistributedWorkUnitFeedContext] can be extracted from
33 /// the [TaskContext], and it contains information about the amount of distributed tasks to
34 /// which [WorkUnit]s should be fanned out.
35 ///
36 /// The implementation should be prepared to return `P*T` feeds, where `P` is the number of
37 /// partitions of the [datafusion::physical_plan::ExecutionPlan] to which the
38 /// [WorkUnitFeedProvider] is attached and `T` is the number of tasks to which it should fanout
39 ///
40 /// For more information about how [WorkUnit] feeds work, refer to the [crate::WorkUnitFeed]
41 /// docs.
42 ///
43 /// # Example
44 ///
45 /// ```rust
46 /// # use std::sync::Arc;
47 /// # use datafusion_distributed::{DistributedWorkUnitFeedContext, WorkUnitFeedProvider};
48 /// # use datafusion::common::Result;
49 /// # use datafusion::execution::TaskContext;
50 /// # use futures::stream::BoxStream;
51 /// # use futures::StreamExt;
52 ///
53 /// #[derive(Debug)]
54 /// struct MyFeedProvider {
55 /// output_partitions: usize
56 /// };
57 ///
58 /// #[derive(Clone, PartialEq, ::prost::Message)]
59 /// struct MyCustomWorkUnit {
60 /// #[prost(string, tag = "1")]
61 /// custom_field: String,
62 /// }
63 ///
64 /// impl WorkUnitFeedProvider for MyFeedProvider {
65 /// type WorkUnit = MyCustomWorkUnit;
66 ///
67 /// fn feed(
68 /// &self,
69 /// partition: usize,
70 /// ctx: Arc<TaskContext>,
71 /// ) -> Result<BoxStream<'static, Result<Self::WorkUnit>>> {
72 /// let feed_ctx = DistributedWorkUnitFeedContext::from_ctx(&ctx);
73 ///
74 /// // this method will be called `feed_ctx.fan_out_tasks * self.output_partitions`
75 /// // times.
76 /// Ok(futures::stream::empty().boxed())
77 /// }
78 /// }
79 /// ```
80 fn feed(
81 &self,
82 partition: usize,
83 ctx: Arc<TaskContext>,
84 ) -> Result<BoxStream<'static, Result<Self::WorkUnit>>>;
85
86 /// DataFusion metrics collected at runtime while streaming [WorkUnit]s through [Self::feed].
87 fn metrics(&self) -> ExecutionPlanMetricsSet {
88 ExecutionPlanMetricsSet::new()
89 }
90}
91
/// Contextual information about where a [`WorkUnitFeedProvider`] is being executed.
///
/// When a [`WorkUnitFeedProvider`] is used in a distributed query, it may run on the
/// coordinating stage, which fans work units out to remote tasks. It may also run purely
/// locally, when the query did not need any remote execution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DistributedWorkUnitFeedContext {
    /// The number of distributed tasks to which the [`WorkUnitFeedProvider`] should fan out.
    pub fan_out_tasks: usize,
}
100
101impl DistributedWorkUnitFeedContext {
102 /// Gets the [DistributedWorkUnitFeedContext] from the [TaskContext] as an extension.
103 /// If no [DistributedWorkUnitFeedContext] is present, returns one valid for single-node
104 /// execution.
105 pub fn from_ctx(ctx: &Arc<TaskContext>) -> Arc<Self> {
106 ctx.session_config()
107 .get_extension::<Self>()
108 .unwrap_or(Arc::new(DistributedWorkUnitFeedContext {
109 fan_out_tasks: 1,
110 }))
111 }
112}