datafusion_distributed/worker/
worker_service.rs1use crate::worker::WorkerSessionBuilder;
2use crate::worker::generated::worker::worker_service_server::{WorkerService, WorkerServiceServer};
3use crate::worker::generated::worker::{
4 CoordinatorToWorkerMsg, ExecuteTaskRequest, TaskKey, WorkerToCoordinatorMsg,
5};
6use crate::worker::impl_execute_task::execute_remote_task;
7use crate::worker::single_write_multi_read::SingleWriteMultiRead;
8use crate::worker::task_data::TaskData;
9use crate::{
10 DefaultSessionBuilder, GetWorkerInfoRequest, GetWorkerInfoResponse, ObservabilityServiceImpl,
11 ObservabilityServiceServer, WorkerResolver,
12};
13use arrow_flight::FlightData;
14use async_trait::async_trait;
15use datafusion::common::DataFusionError;
16use datafusion::execution::runtime_env::RuntimeEnv;
17use datafusion::physical_plan::ExecutionPlan;
18use moka::future::Cache;
19use std::borrow::Cow;
20use std::sync::Arc;
21use std::time::Duration;
22use tonic::codegen::BoxStream;
23use tonic::{Request, Response, Status, Streaming};
24
/// Time-to-idle handed to the task cache built in [`Worker::default`]: ten minutes,
/// spelled out in seconds.
const TASK_CACHE_TTI: Duration = Duration::from_secs(10 * 60);
26
27#[allow(clippy::type_complexity)]
28#[derive(Clone, Default)]
29pub(super) struct WorkerHooks {
30 pub(super) on_plan:
31 Vec<Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>>,
32}
33
34pub(crate) type ResultTaskData = Result<TaskData, Arc<DataFusionError>>;
35pub(crate) type TaskDataEntries = Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>;
36
37#[derive(Clone)]
38pub struct Worker {
39 pub(super) runtime: Arc<RuntimeEnv>,
40 pub(super) task_data_entries: Arc<TaskDataEntries>,
44 pub(super) session_builder: Arc<dyn WorkerSessionBuilder + Send + Sync>,
45 pub(super) hooks: WorkerHooks,
46 pub(super) max_message_size: Option<usize>,
47 pub(super) version: Cow<'static, str>,
48}
49
50impl Default for Worker {
51 fn default() -> Self {
52 let cache = Cache::builder().time_to_idle(TASK_CACHE_TTI).build();
53 Self {
54 runtime: Arc::new(RuntimeEnv::default()),
55 task_data_entries: Arc::new(cache),
56 session_builder: Arc::new(DefaultSessionBuilder),
57 hooks: WorkerHooks::default(),
58 max_message_size: Some(usize::MAX),
59 version: Cow::Borrowed(""),
60 }
61 }
62}
63
64impl Worker {
65 pub fn from_session_builder(
68 session_builder: impl WorkerSessionBuilder + Send + Sync + 'static,
69 ) -> Self {
70 Self {
71 session_builder: Arc::new(session_builder),
72 ..Default::default()
73 }
74 }
75
76 pub fn with_runtime_env(mut self, runtime_env: Arc<RuntimeEnv>) -> Self {
79 self.runtime = runtime_env;
80 self
81 }
82
83 pub fn add_on_plan_hook(
89 &mut self,
90 hook: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
91 ) {
92 self.hooks.on_plan.push(Arc::new(hook));
93 }
94
95 pub fn with_max_message_size(mut self, size: usize) -> Self {
106 self.max_message_size = Some(size);
107 self
108 }
109
110 pub fn into_worker_server(self) -> WorkerServiceServer<Self> {
138 WorkerServiceServer::new(self)
139 .max_decoding_message_size(usize::MAX)
140 .max_encoding_message_size(usize::MAX)
141 }
142
143 pub fn with_observability_service(
149 &self,
150 worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
151 ) -> ObservabilityServiceServer<ObservabilityServiceImpl> {
152 ObservabilityServiceServer::new(ObservabilityServiceImpl::new(
153 self.task_data_entries.clone(),
154 worker_resolver,
155 ))
156 }
157
158 pub fn with_version(mut self, version: impl Into<Cow<'static, str>>) -> Self {
160 self.version = version.into();
161 self
162 }
163
164 #[cfg(any(test, feature = "integration"))]
166 pub async fn tasks_running(&self) -> usize {
167 self.task_data_entries.run_pending_tasks().await;
170 self.task_data_entries.entry_count() as usize
171 }
172}
173
174#[async_trait]
179impl WorkerService for Worker {
180 type CoordinatorChannelStream = BoxStream<WorkerToCoordinatorMsg>;
181
182 async fn coordinator_channel(
183 &self,
184 request: Request<Streaming<CoordinatorToWorkerMsg>>,
185 ) -> Result<Response<Self::CoordinatorChannelStream>, Status> {
186 self.impl_coordinator_channel(request).await
187 }
188
189 type ExecuteTaskStream = BoxStream<FlightData>;
190
191 async fn execute_task(
192 &self,
193 request: Request<ExecuteTaskRequest>,
194 ) -> Result<Response<Self::ExecuteTaskStream>, Status> {
195 execute_remote_task(&self.task_data_entries, request).await
196 }
197
198 async fn get_worker_info(
199 &self,
200 _request: Request<GetWorkerInfoRequest>,
201 ) -> Result<Response<GetWorkerInfoResponse>, Status> {
202 Ok(Response::new(GetWorkerInfoResponse {
203 version: self.version.to_string(),
204 }))
205 }
206}