Skip to main content

datafusion_distributed/worker/
worker_service.rs

1use crate::worker::WorkerSessionBuilder;
2use crate::worker::generated::worker::worker_service_server::{WorkerService, WorkerServiceServer};
3use crate::worker::generated::worker::{
4    CoordinatorToWorkerMsg, ExecuteTaskRequest, TaskKey, WorkerToCoordinatorMsg,
5};
6use crate::worker::impl_execute_task::execute_remote_task;
7use crate::worker::single_write_multi_read::SingleWriteMultiRead;
8use crate::worker::task_data::TaskData;
9use crate::{
10    DefaultSessionBuilder, GetWorkerInfoRequest, GetWorkerInfoResponse, ObservabilityServiceImpl,
11    ObservabilityServiceServer, WorkerResolver,
12};
13use arrow_flight::FlightData;
14use async_trait::async_trait;
15use datafusion::common::DataFusionError;
16use datafusion::execution::runtime_env::RuntimeEnv;
17use datafusion::physical_plan::ExecutionPlan;
18use moka::future::Cache;
19use std::borrow::Cow;
20use std::sync::Arc;
21use std::time::Duration;
22use tonic::codegen::BoxStream;
23use tonic::{Request, Response, Status, Streaming};
24
/// Idle period after which an unused entry is dropped from a worker's task cache
/// (ten minutes, expressed in seconds).
const TASK_CACHE_TTI: Duration = Duration::from_secs(10 * 60);
26
27#[allow(clippy::type_complexity)]
28#[derive(Clone, Default)]
29pub(super) struct WorkerHooks {
30    pub(super) on_plan:
31        Vec<Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>>,
32}
33
34pub(crate) type ResultTaskData = Result<TaskData, Arc<DataFusionError>>;
35pub(crate) type TaskDataEntries = Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>;
36
37#[derive(Clone)]
38pub struct Worker {
39    pub(super) runtime: Arc<RuntimeEnv>,
40    /// TTL-based cache for task execution data. Entries are automatically evicted after
41    /// TASK_CACHE_TTI seconds. This prevents memory leaks from abandoned or incomplete queries
42    /// while allowing concurrent access to task results across multiple partition requests.
43    pub(super) task_data_entries: Arc<TaskDataEntries>,
44    pub(super) session_builder: Arc<dyn WorkerSessionBuilder + Send + Sync>,
45    pub(super) hooks: WorkerHooks,
46    pub(super) max_message_size: Option<usize>,
47    pub(super) version: Cow<'static, str>,
48}
49
50impl Default for Worker {
51    fn default() -> Self {
52        let cache = Cache::builder().time_to_idle(TASK_CACHE_TTI).build();
53        Self {
54            runtime: Arc::new(RuntimeEnv::default()),
55            task_data_entries: Arc::new(cache),
56            session_builder: Arc::new(DefaultSessionBuilder),
57            hooks: WorkerHooks::default(),
58            max_message_size: Some(usize::MAX),
59            version: Cow::Borrowed(""),
60        }
61    }
62}
63
64impl Worker {
65    /// Builds a [Worker] with a custom [WorkerSessionBuilder]. Use this
66    /// method whenever you need to add custom stuff to the `SessionContext` that executes the query.
67    pub fn from_session_builder(
68        session_builder: impl WorkerSessionBuilder + Send + Sync + 'static,
69    ) -> Self {
70        Self {
71            session_builder: Arc::new(session_builder),
72            ..Default::default()
73        }
74    }
75
76    /// Sets a [RuntimeEnv] to be used in all the queries this [Worker] will handle during
77    /// its lifetime.
78    pub fn with_runtime_env(mut self, runtime_env: Arc<RuntimeEnv>) -> Self {
79        self.runtime = runtime_env;
80        self
81    }
82
83    /// Adds a callback for when an [ExecutionPlan] is received in the `set_plan` call.
84    ///
85    /// The callback takes the plan and returns another plan that must be either the same,
86    /// or equivalent in terms of execution. Mutating the plan by adding nodes or removing them
87    /// will make the query blow up in unexpected ways.
88    pub fn add_on_plan_hook(
89        &mut self,
90        hook: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
91    ) {
92        self.hooks.on_plan.push(Arc::new(hook));
93    }
94
95    /// Set the maximum message size for FlightData chunks.
96    ///
97    /// Defaults to `usize::MAX` to minimize chunking overhead for internal communication.
98    /// See [`FlightDataEncoderBuilder::with_max_flight_data_size`] for details.
99    ///
100    /// If you change this to a lower value, ensure you configure the server's
101    /// max_encoding_message_size and max_decoding_message_size to at least 2x this value
102    /// to allow for overhead. For most use cases, the default of `usize::MAX` is appropriate.
103    ///
104    /// [`FlightDataEncoderBuilder::with_max_flight_data_size`]: https://arrow.apache.org/rust/arrow_flight/encode/struct.FlightDataEncoderBuilder.html#structfield.max_flight_data_size
105    pub fn with_max_message_size(mut self, size: usize) -> Self {
106        self.max_message_size = Some(size);
107        self
108    }
109
110    /// Converts this [Worker] into a [`WorkerServiceServer`] with high default message size limits.
111    ///
112    /// This is a convenience method that wraps the endpoint in a [`WorkerServiceServer`] and
113    /// configures it with `max_decoding_message_size(usize::MAX)` and
114    /// `max_encoding_message_size(usize::MAX)` to avoid message size limitations for internal
115    /// communication.
116    ///
117    /// You can further customize the returned server by chaining additional tonic methods.
118    ///
119    /// # Example
120    ///
121    /// ```
122    /// # use datafusion_distributed::Worker;
123    /// # use tonic::transport::Server;
124    /// # use std::net::{IpAddr, Ipv4Addr, SocketAddr};
125    /// # async fn f() {
126    ///
127    /// let worker = Worker::default();
128    /// let server = worker.into_worker_server();
129    ///
130    /// Server::builder()
131    ///     .add_service(Worker::default().into_worker_server())
132    ///     .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080))
133    ///     .await;
134    ///
135    /// # }
136    /// ```
137    pub fn into_worker_server(self) -> WorkerServiceServer<Self> {
138        WorkerServiceServer::new(self)
139            .max_decoding_message_size(usize::MAX)
140            .max_encoding_message_size(usize::MAX)
141    }
142
143    /// Creates an [`ObservabilityServiceServer`] that exposes task progress and cluster
144    /// worker discovery via the provided [`WorkerResolver`].
145    ///
146    /// The returned server is meant to be added to the same [`tonic::transport::Server`] as the
147    /// Flight service — gRPC multiplexes both services on a single port.
148    pub fn with_observability_service(
149        &self,
150        worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
151    ) -> ObservabilityServiceServer<ObservabilityServiceImpl> {
152        ObservabilityServiceServer::new(ObservabilityServiceImpl::new(
153            self.task_data_entries.clone(),
154            worker_resolver,
155        ))
156    }
157
158    /// Sets a version string reported by the `GetWorkerInfo` gRPC endpoint.
159    pub fn with_version(mut self, version: impl Into<Cow<'static, str>>) -> Self {
160        self.version = version.into();
161        self
162    }
163
164    /// Returns the number of cached task entries currently held by this worker.
165    #[cfg(any(test, feature = "integration"))]
166    pub async fn tasks_running(&self) -> usize {
167        // Use `run_pending_tasks()` to migigate inaccuracy from potential stale
168        // `entry_count()` task data.
169        self.task_data_entries.run_pending_tasks().await;
170        self.task_data_entries.entry_count() as usize
171    }
172}
173
174/// Implementation of the `worker.proto` specification based on the generated Rust stubs.
175///
176/// The methods are delegated to plan `impl Worker` implementations so that they can be implemented
177/// in different files.
178#[async_trait]
179impl WorkerService for Worker {
180    type CoordinatorChannelStream = BoxStream<WorkerToCoordinatorMsg>;
181
182    async fn coordinator_channel(
183        &self,
184        request: Request<Streaming<CoordinatorToWorkerMsg>>,
185    ) -> Result<Response<Self::CoordinatorChannelStream>, Status> {
186        self.impl_coordinator_channel(request).await
187    }
188
189    type ExecuteTaskStream = BoxStream<FlightData>;
190
191    async fn execute_task(
192        &self,
193        request: Request<ExecuteTaskRequest>,
194    ) -> Result<Response<Self::ExecuteTaskStream>, Status> {
195        execute_remote_task(&self.task_data_entries, request).await
196    }
197
198    async fn get_worker_info(
199        &self,
200        _request: Request<GetWorkerInfoRequest>,
201    ) -> Result<Response<GetWorkerInfoResponse>, Status> {
202        Ok(Response::new(GetWorkerInfoResponse {
203            version: self.version.to_string(),
204        }))
205    }
206}