use crate::worker::WorkerSessionBuilder;
use crate::worker::generated::worker::worker_service_server::{WorkerService, WorkerServiceServer};
use crate::worker::generated::worker::{
CoordinatorToWorkerMsg, ExecuteTaskRequest, TaskKey, WorkerToCoordinatorMsg,
};
use crate::worker::impl_execute_task::execute_remote_task;
use crate::worker::single_write_multi_read::SingleWriteMultiRead;
use crate::worker::task_data::TaskData;
use crate::{
DefaultSessionBuilder, GetWorkerInfoRequest, GetWorkerInfoResponse, ObservabilityServiceImpl,
ObservabilityServiceServer, WorkerResolver,
};
use arrow_flight::FlightData;
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::ExecutionPlan;
use moka::future::Cache;
use std::borrow::Cow;
use std::sync::Arc;
use std::time::Duration;
use tonic::codegen::BoxStream;
use tonic::{Request, Response, Status, Streaming};
/// Idle duration (ten minutes) handed to the task cache builder's `time_to_idle` setting.
const TASK_CACHE_TTI: Duration = Duration::from_secs(10 * 60);
/// Signature of an `on_plan` hook: a thread-safe callable that receives an execution plan and
/// returns an execution plan.
pub(super) type OnPlanHook =
    Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>;

/// Callbacks registered on a worker.
#[derive(Clone, Default)]
pub(super) struct WorkerHooks {
    /// Plan hooks, kept in the order they were pushed.
    pub(super) on_plan: Vec<OnPlanHook>,
}
/// Either a task's `TaskData`, or the `DataFusionError` that occurred instead, held behind an
/// `Arc`.
pub(crate) type ResultTaskData = Result<TaskData, Arc<DataFusionError>>;
/// Cache mapping a `TaskKey` to a shared `SingleWriteMultiRead` cell that wraps the task's
/// `ResultTaskData`.
// NOTE(review): the cell's name suggests one writer and many readers — confirm against the
// `single_write_multi_read` module.
pub(crate) type TaskDataEntries = Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>;
/// State behind the `WorkerService` gRPC implementation.
///
/// The type is `Clone`; the runtime, task cache and session builder are held behind `Arc`s, so
/// clones share them.
#[derive(Clone)]
pub struct Worker {
/// DataFusion runtime environment; replaceable through `with_runtime_env`.
pub(super) runtime: Arc<RuntimeEnv>,
/// Task cache keyed by `TaskKey`. It is passed to `execute_remote_task` and is also shared
/// with the observability service.
pub(super) task_data_entries: Arc<TaskDataEntries>,
/// Session builder supplied via `from_session_builder` (`DefaultSessionBuilder` by default).
pub(super) session_builder: Arc<dyn WorkerSessionBuilder + Send + Sync>,
/// Hooks registered on this worker (see `add_on_plan_hook`).
pub(super) hooks: WorkerHooks,
/// Message size setting; `Some(usize::MAX)` by default, overridden by `with_max_message_size`.
// NOTE(review): this field is not read anywhere in this part of the file, and
// `into_worker_server` hard-codes `usize::MAX` — confirm where the value is consumed.
pub(super) max_message_size: Option<usize>,
/// Version string returned by `get_worker_info`; empty by default.
pub(super) version: Cow<'static, str>,
}
impl Default for Worker {
fn default() -> Self {
let cache = Cache::builder().time_to_idle(TASK_CACHE_TTI).build();
Self {
runtime: Arc::new(RuntimeEnv::default()),
task_data_entries: Arc::new(cache),
session_builder: Arc::new(DefaultSessionBuilder),
hooks: WorkerHooks::default(),
max_message_size: Some(usize::MAX),
version: Cow::Borrowed(""),
}
}
}
impl Worker {
    /// Creates a worker that uses `session_builder`, leaving every other setting at its
    /// [`Default`] value.
    pub fn from_session_builder(
        session_builder: impl WorkerSessionBuilder + Send + Sync + 'static,
    ) -> Self {
        let session_builder: Arc<dyn WorkerSessionBuilder + Send + Sync> =
            Arc::new(session_builder);
        Self {
            session_builder,
            ..Self::default()
        }
    }

    /// Returns the worker with its runtime environment replaced by `runtime_env`.
    pub fn with_runtime_env(self, runtime_env: Arc<RuntimeEnv>) -> Self {
        Self {
            runtime: runtime_env,
            ..self
        }
    }

    /// Appends `hook` to the list of `on_plan` hooks.
    ///
    /// Unlike the `with_*` builders this mutates the worker in place.
    pub fn add_on_plan_hook(
        &mut self,
        hook: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
    ) {
        let registered: Arc<
            dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send,
        > = Arc::new(hook);
        self.hooks.on_plan.push(registered);
    }

    /// Returns the worker with its message size setting replaced by `Some(size)`.
    pub fn with_max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: Some(size),
            ..self
        }
    }

    /// Wraps the worker in the generated gRPC server type, setting both the decoding and the
    /// encoding message size limits to `usize::MAX`.
    pub fn into_worker_server(self) -> WorkerServiceServer<Self> {
        let limit = usize::MAX;
        let server = WorkerServiceServer::new(self);
        server
            .max_decoding_message_size(limit)
            .max_encoding_message_size(limit)
    }

    /// Builds an observability service server that shares this worker's task cache and uses
    /// `worker_resolver`.
    pub fn with_observability_service(
        &self,
        worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
    ) -> ObservabilityServiceServer<ObservabilityServiceImpl> {
        let entries = Arc::clone(&self.task_data_entries);
        let service = ObservabilityServiceImpl::new(entries, worker_resolver);
        ObservabilityServiceServer::new(service)
    }

    /// Returns the worker with its version string replaced by `version`.
    pub fn with_version(self, version: impl Into<Cow<'static, str>>) -> Self {
        Self {
            version: version.into(),
            ..self
        }
    }

    /// Number of entries in the task cache, read after running the cache's pending tasks.
    ///
    /// Only compiled for tests and the `integration` feature.
    #[cfg(any(test, feature = "integration"))]
    pub async fn tasks_running(&self) -> usize {
        let entries = &self.task_data_entries;
        entries.run_pending_tasks().await;
        entries.entry_count() as usize
    }
}
#[async_trait]
impl WorkerService for Worker {
    /// Stream of messages sent from this worker back to the coordinator.
    type CoordinatorChannelStream = BoxStream<WorkerToCoordinatorMsg>;
    /// Stream of `FlightData` produced by a task execution.
    type ExecuteTaskStream = BoxStream<FlightData>;

    /// Handles the coordinator channel RPC by delegating to `impl_coordinator_channel`.
    async fn coordinator_channel(
        &self,
        request: Request<Streaming<CoordinatorToWorkerMsg>>,
    ) -> Result<Response<Self::CoordinatorChannelStream>, Status> {
        self.impl_coordinator_channel(request).await
    }

    /// Handles the execute-task RPC by delegating to `execute_remote_task` with this worker's
    /// task cache.
    async fn execute_task(
        &self,
        request: Request<ExecuteTaskRequest>,
    ) -> Result<Response<Self::ExecuteTaskStream>, Status> {
        let entries = &self.task_data_entries;
        execute_remote_task(entries, request).await
    }

    /// Reports this worker's version string; the request payload is ignored.
    async fn get_worker_info(
        &self,
        _request: Request<GetWorkerInfoRequest>,
    ) -> Result<Response<GetWorkerInfoResponse>, Status> {
        let version: String = self.version.clone().into_owned();
        let info = GetWorkerInfoResponse { version };
        Ok(Response::new(info))
    }
}