// datafusion_distributed/work_unit_feed/work_unit_feed_provider.rs

use crate::WorkUnit;
use datafusion::common::Result;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use futures::stream::BoxStream;
use std::fmt::Debug;
use std::sync::Arc;
8
9/// Extension point for building user-defined work unit streams consumed by a
10/// [`crate::WorkUnitFeed`] embedded in a leaf [`datafusion::physical_plan::ExecutionPlan`].
11///
12/// Implement this trait on a type that knows how to produce the per-partition stream of
13/// work items (e.g. file addresses, external queries, key ranges) that the leaf plan needs
14/// at runtime. Then wrap the implementation with [`crate::WorkUnitFeed::new`] and store
15/// the resulting [`crate::WorkUnitFeed`] as a field of your [`ExecutionPlan`] node.
16///
17/// In a distributed context the provider is only invoked on the **coordinating** stage
18/// that initiates the query. The work units it produces are serialized and streamed over
19/// the network to the workers, which expose the same typed stream to the leaf plan as if
20/// it were running locally.
21///
22/// See [`WorkUnitFeedProvider::feed`] for the per-call contract.
23pub trait WorkUnitFeedProvider: Send + Sync + Debug {
24    type WorkUnit: WorkUnit + Default;
25
26    /// Builds a [`WorkUnit`] stream for the given `partition`.
27    ///
28    /// This method is never invoked in a remote worker. On workers, the equivalent
29    /// leaf plan uses a remote provider that pulls the work units off the network —
30    /// user code doesn't need to implement that case.
31    ///
32    /// When implementing this method, [DistributedWorkUnitFeedContext] can be extracted from
33    /// the [TaskContext], and it contains information about the amount of distributed tasks to
34    /// which [WorkUnit]s should be fanned out.
35    ///
36    /// The implementation should be prepared to return `P*T` feeds, where `P` is the number of
37    /// partitions of the [datafusion::physical_plan::ExecutionPlan] to which the
38    /// [WorkUnitFeedProvider] is attached and `T` is the number of tasks to which it should fanout
39    ///
40    /// For more information about how [WorkUnit] feeds work, refer to the [crate::WorkUnitFeed]
41    /// docs.
42    ///
43    /// # Example
44    ///
45    /// ```rust
46    /// # use std::sync::Arc;
47    /// # use datafusion_distributed::{DistributedWorkUnitFeedContext, WorkUnitFeedProvider};
48    /// # use datafusion::common::Result;
49    /// # use datafusion::execution::TaskContext;
50    /// # use futures::stream::BoxStream;
51    /// # use futures::StreamExt;
52    ///
53    /// #[derive(Debug)]
54    /// struct MyFeedProvider {
55    ///     output_partitions: usize
56    /// };
57    ///
58    /// #[derive(Clone, PartialEq, ::prost::Message)]
59    /// struct MyCustomWorkUnit {
60    ///     #[prost(string, tag = "1")]
61    ///     custom_field: String,
62    /// }
63    ///
64    /// impl WorkUnitFeedProvider for MyFeedProvider {
65    ///     type WorkUnit = MyCustomWorkUnit;
66    ///
67    ///     fn feed(
68    ///         &self,
69    ///         partition: usize,
70    ///         ctx: Arc<TaskContext>,
71    ///     ) -> Result<BoxStream<'static, Result<Self::WorkUnit>>> {
72    ///         let feed_ctx = DistributedWorkUnitFeedContext::from_ctx(&ctx);
73    ///
74    ///         // this method will be called `feed_ctx.fan_out_tasks * self.output_partitions`
75    ///         // times.
76    ///         Ok(futures::stream::empty().boxed())
77    ///     }
78    /// }
79    /// ```
80    fn feed(
81        &self,
82        partition: usize,
83        ctx: Arc<TaskContext>,
84    ) -> Result<BoxStream<'static, Result<Self::WorkUnit>>>;
85
86    /// DataFusion metrics collected at runtime while streaming [WorkUnit]s through [Self::feed].
87    fn metrics(&self) -> ExecutionPlanMetricsSet {
88        ExecutionPlanMetricsSet::new()
89    }
90}
91
/// Provides contextual information about where a [`WorkUnitFeedProvider`] is being executed.
/// When using [`WorkUnitFeedProvider`] in distributed queries, it might be getting executed in
/// the coordinating stage, or it might be getting executed just locally because the query did
/// not need any remote execution.
#[derive(Debug, Clone)]
pub struct DistributedWorkUnitFeedContext {
    /// The number of distributed tasks to which the [`WorkUnitFeedProvider`] should fan out.
    pub fan_out_tasks: usize,
}
100
101impl DistributedWorkUnitFeedContext {
102    /// Gets the [DistributedWorkUnitFeedContext] from the [TaskContext] as an extension.
103    /// If no [DistributedWorkUnitFeedContext] is present, returns one valid for single-node
104    /// execution.
105    pub fn from_ctx(ctx: &Arc<TaskContext>) -> Arc<Self> {
106        ctx.session_config()
107            .get_extension::<Self>()
108            .unwrap_or(Arc::new(DistributedWorkUnitFeedContext {
109                fan_out_tasks: 1,
110            }))
111    }
112}