use crate::common::deserialize_uuid;
use crate::work_unit_feed::{RemoteWorkUnitFeedRegistry, set_work_unit_received_time};
use crate::worker::LocalWorkerContext;
use crate::worker::generated::worker::coordinator_to_worker_msg::Inner;
use crate::worker::generated::worker::set_plan_request::WorkUnitFeedDeclaration;
use crate::worker::generated::worker::worker_service_server::WorkerService;
use crate::worker::generated::worker::{
CoordinatorToWorkerMsg, WorkerToCoordinatorMsg, worker_to_coordinator_msg,
};
use crate::worker::task_data::TaskDataMetrics;
use crate::{
DistributedCodec, DistributedConfig, DistributedExt, DistributedTaskContext, TaskData, Worker,
WorkerQueryContext,
};
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::SessionConfig;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf::PhysicalPlanNode;
use futures::{FutureExt, StreamExt, TryStreamExt};
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, OnceLock};
use tokio::sync::oneshot;
use tonic::{Request, Response, Status, Streaming};
use url::Url;
impl Worker {
pub(super) async fn impl_coordinator_channel(
&self,
request: Request<Streaming<CoordinatorToWorkerMsg>>,
) -> Result<Response<<Worker as WorkerService>::CoordinatorChannelStream>, Status> {
let (grpc_headers, _ext, mut body) = request.into_parts();
let Some(msg) = body.next().await else {
return Err(Status::internal("Empty Coordinator stream"));
};
let Some(Inner::SetPlanRequest(request)) = msg?.inner else {
return Err(Status::internal(
"First Coordinator message must be SetPlanRequest",
));
};
let key = request.task_key.ok_or_else(missing("task_key"))?;
let entry = self
.task_data_entries
.get_with(key.clone(), async { Default::default() })
.await;
let mut remote_work_unit_feed_registry = RemoteWorkUnitFeedRegistry::default();
for WorkUnitFeedDeclaration { id, partitions } in &request.work_unit_feed_declarations {
if let Ok(id) = deserialize_uuid(id) {
remote_work_unit_feed_registry.add(id, *partitions as usize);
}
}
let (metrics_tx, metrics_rx) = oneshot::channel();
let task_data = || async {
let headers = grpc_headers.into_headers();
let mut cfg = SessionConfig::default()
.with_extension(Arc::new(remote_work_unit_feed_registry.receivers))
.with_extension(Arc::new(DistributedTaskContext {
task_index: key.task_number as usize,
task_count: request.task_count as usize,
}))
.with_extension(Arc::new(LocalWorkerContext {
task_data_entries: Arc::clone(&self.task_data_entries),
self_url: Url::parse(&request.target_worker_url)
.map_err(|e| DataFusionError::External(Box::new(e)))?,
}))
.with_distributed_option_extension_from_headers::<DistributedConfig>(&headers)?;
let d_cfg = DistributedConfig::from_config_options(cfg.options())?;
let shuffle_batch_size = d_cfg.shuffle_batch_size;
let collect_metrics = d_cfg.collect_metrics;
if shuffle_batch_size != 0 {
cfg = cfg.with_batch_size(shuffle_batch_size);
}
let session_state = self
.session_builder
.build_session_state(WorkerQueryContext {
builder: SessionStateBuilder::new()
.with_default_features()
.with_config(cfg)
.with_runtime_env(Arc::clone(&self.runtime)),
headers,
})
.await?;
let codec = DistributedCodec::new_combined_with_user(session_state.config());
let task_ctx = session_state.task_ctx();
let proto_node = PhysicalPlanNode::try_decode(request.plan_proto.as_ref())?;
let mut plan = proto_node.try_into_physical_plan(&task_ctx, &codec)?;
for hook in self.hooks.on_plan.iter() {
plan = hook(plan)
}
let total_partitions = plan.properties().partitioning.partition_count();
Ok::<_, DataFusionError>(TaskData {
base_plan: plan,
final_plan: Arc::new(OnceLock::new()),
task_ctx,
num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
metrics_tx: match collect_metrics {
true => Arc::new(std::sync::Mutex::new(Some(metrics_tx))),
false => Arc::new(std::sync::Mutex::new(None)),
},
task_data_metrics: Arc::new(TaskDataMetrics::new(request.query_start_time_ns)),
})
};
entry.write(task_data().await.map_err(Arc::new)).map_err(|_| {
Status::internal(format!(
"Logic error while setting plan for TaskKey {key:?}: the plan was set twice. This is a bug in datafusion-distributed, please report it."
))
})?;
let mut work_unit_senders = remote_work_unit_feed_registry.senders;
#[allow(clippy::disallowed_methods)]
tokio::spawn(async move {
let mut body = body.map_ok(set_work_unit_received_time);
while let Some(Ok(msg)) = body.next().await {
let Some(Inner::WorkUnitBatch(msg)) = msg.inner else {
continue;
};
for msg in msg.batch {
let Ok(id) = deserialize_uuid(&msg.id) else {
continue;
};
let partition = msg.partition as usize;
let Some(tx) = work_unit_senders.get(&(id, partition)) else {
continue;
};
if tx.send(Ok(msg)).is_err() {
work_unit_senders.remove(&(id, partition));
continue;
}
}
}
});
let metrics_stream = metrics_rx.into_stream();
let metrics_stream = metrics_stream.filter_map(|task_metrics| async move {
match task_metrics {
Ok(task_metrics) => Some(WorkerToCoordinatorMsg {
inner: Some(worker_to_coordinator_msg::Inner::TaskMetrics(task_metrics)),
}),
Err(_) => None, }
});
Ok(Response::new(metrics_stream.map(Ok).boxed()))
}
}
fn missing(field: &'static str) -> impl FnOnce() -> Status {
move || Status::invalid_argument(format!("Missing field '{field}'"))
}