Skip to main content

datafusion_distributed/worker/
worker_service.rs

1use crate::worker::WorkerSessionBuilder;
2use crate::worker::generated::worker::coordinator_to_worker_msg::Inner;
3use crate::worker::generated::worker::worker_service_server::{WorkerService, WorkerServiceServer};
4use crate::worker::generated::worker::{
5    CoordinatorToWorkerMsg, ExecuteTaskRequest, TaskKey, WorkerToCoordinatorMsg,
6};
7use crate::worker::impl_set_plan::TaskData;
8use crate::worker::single_write_multi_read::SingleWriteMultiRead;
9use crate::{
10    DefaultSessionBuilder, ObservabilityServiceImpl, ObservabilityServiceServer, WorkerResolver,
11};
12use arrow_flight::FlightData;
13use async_trait::async_trait;
14use datafusion::common::DataFusionError;
15use datafusion::execution::runtime_env::RuntimeEnv;
16use datafusion::physical_plan::ExecutionPlan;
17use futures::StreamExt;
18use moka::future::Cache;
19use std::borrow::Cow;
20use std::sync::Arc;
21use std::time::Duration;
22use tonic::codegen::BoxStream;
23use tonic::{Request, Response, Status, Streaming};
24
25use super::generated::worker::{GetWorkerInfoRequest, GetWorkerInfoResponse};
26
27#[allow(clippy::type_complexity)]
28#[derive(Clone, Default)]
29pub(super) struct WorkerHooks {
30    pub(super) on_plan:
31        Vec<Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>>,
32}
33
34type ResultTaskData = Result<TaskData, Arc<DataFusionError>>;
35
36#[derive(Clone)]
37pub struct Worker {
38    pub(super) runtime: Arc<RuntimeEnv>,
39    /// TTL-based cache for task execution data. Entries are automatically evicted after 60 seconds.
40    /// This prevents memory leaks from abandoned or incomplete queries while allowing concurrent
41    /// access to task results across multiple partition requests.
42    pub(super) task_data_entries: Arc<Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>>,
43    pub(super) session_builder: Arc<dyn WorkerSessionBuilder + Send + Sync>,
44    pub(super) hooks: WorkerHooks,
45    pub(super) max_message_size: Option<usize>,
46    pub(super) version: Cow<'static, str>,
47}
48
49impl Default for Worker {
50    fn default() -> Self {
51        let cache = Cache::builder()
52            .time_to_idle(Duration::from_secs(60))
53            .build();
54        Self {
55            runtime: Arc::new(RuntimeEnv::default()),
56            task_data_entries: Arc::new(cache),
57            session_builder: Arc::new(DefaultSessionBuilder),
58            hooks: WorkerHooks::default(),
59            max_message_size: Some(usize::MAX),
60            version: Cow::Borrowed(""),
61        }
62    }
63}
64
65impl Worker {
66    /// Builds a [Worker] with a custom [WorkerSessionBuilder]. Use this
67    /// method whenever you need to add custom stuff to the `SessionContext` that executes the query.
68    pub fn from_session_builder(
69        session_builder: impl WorkerSessionBuilder + Send + Sync + 'static,
70    ) -> Self {
71        Self {
72            session_builder: Arc::new(session_builder),
73            ..Default::default()
74        }
75    }
76
77    /// Sets a [RuntimeEnv] to be used in all the queries this [Worker] will handle during
78    /// its lifetime.
79    pub fn with_runtime_env(mut self, runtime_env: Arc<RuntimeEnv>) -> Self {
80        self.runtime = runtime_env;
81        self
82    }
83
84    /// Adds a callback for when an [ExecutionPlan] is received in the `set_plan` call.
85    ///
86    /// The callback takes the plan and returns another plan that must be either the same,
87    /// or equivalent in terms of execution. Mutating the plan by adding nodes or removing them
88    /// will make the query blow up in unexpected ways.
89    pub fn add_on_plan_hook(
90        &mut self,
91        hook: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
92    ) {
93        self.hooks.on_plan.push(Arc::new(hook));
94    }
95
96    /// Set the maximum message size for FlightData chunks.
97    ///
98    /// Defaults to `usize::MAX` to minimize chunking overhead for internal communication.
99    /// See [`FlightDataEncoderBuilder::with_max_flight_data_size`] for details.
100    ///
101    /// If you change this to a lower value, ensure you configure the server's
102    /// max_encoding_message_size and max_decoding_message_size to at least 2x this value
103    /// to allow for overhead. For most use cases, the default of `usize::MAX` is appropriate.
104    ///
105    /// [`FlightDataEncoderBuilder::with_max_flight_data_size`]: https://arrow.apache.org/rust/arrow_flight/encode/struct.FlightDataEncoderBuilder.html#structfield.max_flight_data_size
106    pub fn with_max_message_size(mut self, size: usize) -> Self {
107        self.max_message_size = Some(size);
108        self
109    }
110
111    /// Converts this [Worker] into a [`WorkerServiceServer`] with high default message size limits.
112    ///
113    /// This is a convenience method that wraps the endpoint in a [`WorkerServiceServer`] and
114    /// configures it with `max_decoding_message_size(usize::MAX)` and
115    /// `max_encoding_message_size(usize::MAX)` to avoid message size limitations for internal
116    /// communication.
117    ///
118    /// You can further customize the returned server by chaining additional tonic methods.
119    ///
120    /// # Example
121    ///
122    /// ```
123    /// # use datafusion_distributed::Worker;
124    /// # use tonic::transport::Server;
125    /// # use std::net::{IpAddr, Ipv4Addr, SocketAddr};
126    /// # async fn f() {
127    ///
128    /// let worker = Worker::default();
129    /// let server = worker.into_worker_server();
130    ///
131    /// Server::builder()
132    ///     .add_service(Worker::default().into_worker_server())
133    ///     .serve(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080))
134    ///     .await;
135    ///
136    /// # }
137    /// ```
138    pub fn into_worker_server(self) -> WorkerServiceServer<Self> {
139        WorkerServiceServer::new(self)
140            .max_decoding_message_size(usize::MAX)
141            .max_encoding_message_size(usize::MAX)
142    }
143
144    /// Creates an [`ObservabilityServiceServer`] that exposes task progress and cluster
145    /// worker discovery via the provided [`WorkerResolver`].
146    ///
147    /// The returned server is meant to be added to the same [`tonic::transport::Server`] as the
148    /// Flight service — gRPC multiplexes both services on a single port.
149    pub fn with_observability_service(
150        &self,
151        worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
152    ) -> ObservabilityServiceServer<ObservabilityServiceImpl> {
153        ObservabilityServiceServer::new(ObservabilityServiceImpl::new(
154            self.task_data_entries.clone(),
155            worker_resolver,
156        ))
157    }
158
159    /// Sets a version string reported by the `GetWorkerInfo` gRPC endpoint.
160    pub fn with_version(mut self, version: impl Into<Cow<'static, str>>) -> Self {
161        self.version = version.into();
162        self
163    }
164
165    /// Returns the number of cached task entries currently held by this worker.
166    #[cfg(any(test, feature = "integration"))]
167    pub async fn tasks_running(&self) -> usize {
168        // Use `run_pending_tasks()` to migigate inaccuracy from potential stale
169        // `entry_count()` task data.
170        self.task_data_entries.run_pending_tasks().await;
171        self.task_data_entries.entry_count() as usize
172    }
173}
174
175/// Implementation of the `worker.proto` specification based on the generated Rust stubs.
176///
177/// The methods are delegated to plan `impl Worker` implementations so that they can be implemented
178/// in different files.
179#[async_trait]
180impl WorkerService for Worker {
181    type CoordinatorChannelStream = BoxStream<WorkerToCoordinatorMsg>;
182
183    async fn coordinator_channel(
184        &self,
185        request: Request<Streaming<CoordinatorToWorkerMsg>>,
186    ) -> Result<Response<Self::CoordinatorChannelStream>, Status> {
187        let (metadata, _ext, mut body) = request.into_parts();
188        if let Some(msg) = body.next().await {
189            let Some(inner) = msg?.inner else {
190                return Err(Status::internal("Empty Coordinator message"));
191            };
192
193            match inner {
194                Inner::SetPlanRequest(request) => {
195                    self.impl_set_plan(request, metadata).await?;
196                }
197            };
198        }
199        Ok(Response::new(futures::stream::empty().boxed()))
200    }
201
202    type ExecuteTaskStream = BoxStream<FlightData>;
203
204    async fn execute_task(
205        &self,
206        request: Request<ExecuteTaskRequest>,
207    ) -> Result<Response<Self::ExecuteTaskStream>, Status> {
208        self.impl_execute_task(request).await
209    }
210
211    async fn get_worker_info(
212        &self,
213        _request: Request<GetWorkerInfoRequest>,
214    ) -> Result<Response<GetWorkerInfoResponse>, Status> {
215        Ok(Response::new(GetWorkerInfoResponse {
216            version: self.version.to_string(),
217        }))
218    }
219}