datafusion_distributed/worker/
impl_set_plan.rs1use crate::config_extension_ext::set_distributed_option_extension_from_headers;
2use crate::protobuf::DistributedCodec;
3use crate::worker::generated::worker::SetPlanRequest;
4use crate::{DistributedConfig, DistributedTaskContext, Worker, WorkerQueryContext};
5use datafusion::error::DataFusionError;
6use datafusion::execution::{SessionStateBuilder, TaskContext};
7use datafusion::physical_plan::ExecutionPlan;
8use datafusion::prelude::SessionConfig;
9use datafusion_proto::physical_plan::AsExecutionPlan;
10use datafusion_proto::protobuf::PhysicalPlanNode;
11use std::sync::Arc;
12use std::sync::atomic::AtomicUsize;
13use tonic::Status;
14use tonic::metadata::MetadataMap;
15
16#[derive(Clone, Debug)]
17pub struct TaskData {
20 pub(super) task_ctx: Arc<TaskContext>,
22 pub(crate) plan: Arc<dyn ExecutionPlan>,
24 pub(super) num_partitions_remaining: Arc<AtomicUsize>,
30}
31
32impl TaskData {
33 pub(crate) fn num_partitions_remaining(&self) -> usize {
35 self.num_partitions_remaining
36 .load(std::sync::atomic::Ordering::Relaxed)
37 }
38
39 pub(crate) fn total_partitions(&self) -> usize {
41 self.plan.properties().partitioning.partition_count()
42 }
43}
44
45impl Worker {
46 pub(crate) async fn impl_set_plan(
47 &self,
48 request: SetPlanRequest,
49 metadata: MetadataMap,
50 ) -> Result<(), Status> {
51 let key = request.task_key.ok_or_else(missing("task_key"))?;
52
53 let entry = self
54 .task_data_entries
55 .get_with(key.clone(), async { Default::default() })
56 .await;
57
58 let task_data = || async {
59 let headers = metadata.into_headers();
60
61 let mut cfg =
62 SessionConfig::default().with_extension(Arc::new(DistributedTaskContext {
63 task_index: key.task_number as usize,
64 task_count: request.task_count as usize,
65 }));
66 set_distributed_option_extension_from_headers::<DistributedConfig>(&mut cfg, &headers)?;
67 let session_state = self
68 .session_builder
69 .build_session_state(WorkerQueryContext {
70 builder: SessionStateBuilder::new()
71 .with_default_features()
72 .with_config(cfg)
73 .with_runtime_env(Arc::clone(&self.runtime)),
74 headers,
75 })
76 .await?;
77
78 let codec = DistributedCodec::new_combined_with_user(session_state.config());
79 let task_ctx = session_state.task_ctx();
80 let proto_node = PhysicalPlanNode::try_decode(request.plan_proto.as_ref())?;
81 let mut plan = proto_node.try_into_physical_plan(&task_ctx, &codec)?;
82
83 for hook in self.hooks.on_plan.iter() {
84 plan = hook(plan)
85 }
86
87 let total_partitions = plan.properties().partitioning.partition_count();
89 Ok::<_, DataFusionError>(TaskData {
90 plan,
91 task_ctx,
92 num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
93 })
94 };
95
96 entry.write(task_data().await.map_err(Arc::new)).map_err(|_| {
97 Status::internal(format!(
98 "Logic error while setting plan for TaskKey {key:?}: the plan was set twice. This is a bug in datafusion-distributed, please report it."
99 ))
100 })?;
101 Ok(())
102 }
103}
104
105fn missing(field: &'static str) -> impl FnOnce() -> Status {
106 move || Status::invalid_argument(format!("Missing field '{field}'"))
107}