// datafusion_distributed/worker/impl_set_plan.rs

1use crate::config_extension_ext::set_distributed_option_extension_from_headers;
2use crate::protobuf::DistributedCodec;
3use crate::worker::generated::worker::SetPlanRequest;
4use crate::{DistributedConfig, DistributedTaskContext, Worker, WorkerQueryContext};
5use datafusion::error::DataFusionError;
6use datafusion::execution::{SessionStateBuilder, TaskContext};
7use datafusion::physical_plan::ExecutionPlan;
8use datafusion::prelude::SessionConfig;
9use datafusion_proto::physical_plan::AsExecutionPlan;
10use datafusion_proto::protobuf::PhysicalPlanNode;
11use std::sync::Arc;
12use std::sync::atomic::AtomicUsize;
13use tonic::Status;
14use tonic::metadata::MetadataMap;
15
16#[derive(Clone, Debug)]
17/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
18/// by concurrent requests for the same task which execute separate partitions.
19pub struct TaskData {
20    /// Task context suitable for execute different partitions from the same task.
21    pub(super) task_ctx: Arc<TaskContext>,
22    /// Plan to be executed.
23    pub(crate) plan: Arc<dyn ExecutionPlan>,
24    /// `num_partitions_remaining` is initialized to the total number of partitions in the task (not
25    /// only tasks in the partition group). This is decremented for each request to the endpoint
26    /// for this task. Once this count is zero, the task is likely complete. The task may not be
27    /// complete because it's possible that the same partition was retried and this count was
28    /// decremented more than once for the same partition.
29    pub(super) num_partitions_remaining: Arc<AtomicUsize>,
30}
31
32impl TaskData {
33    /// Returns the number of partitions remaining to be processed.
34    pub(crate) fn num_partitions_remaining(&self) -> usize {
35        self.num_partitions_remaining
36            .load(std::sync::atomic::Ordering::Relaxed)
37    }
38
39    /// Returns the total number of partitions in this task.
40    pub(crate) fn total_partitions(&self) -> usize {
41        self.plan.properties().partitioning.partition_count()
42    }
43}
44
45impl Worker {
46    pub(crate) async fn impl_set_plan(
47        &self,
48        request: SetPlanRequest,
49        metadata: MetadataMap,
50    ) -> Result<(), Status> {
51        let key = request.task_key.ok_or_else(missing("task_key"))?;
52
53        let entry = self
54            .task_data_entries
55            .get_with(key.clone(), async { Default::default() })
56            .await;
57
58        let task_data = || async {
59            let headers = metadata.into_headers();
60
61            let mut cfg =
62                SessionConfig::default().with_extension(Arc::new(DistributedTaskContext {
63                    task_index: key.task_number as usize,
64                    task_count: request.task_count as usize,
65                }));
66            set_distributed_option_extension_from_headers::<DistributedConfig>(&mut cfg, &headers)?;
67            let session_state = self
68                .session_builder
69                .build_session_state(WorkerQueryContext {
70                    builder: SessionStateBuilder::new()
71                        .with_default_features()
72                        .with_config(cfg)
73                        .with_runtime_env(Arc::clone(&self.runtime)),
74                    headers,
75                })
76                .await?;
77
78            let codec = DistributedCodec::new_combined_with_user(session_state.config());
79            let task_ctx = session_state.task_ctx();
80            let proto_node = PhysicalPlanNode::try_decode(request.plan_proto.as_ref())?;
81            let mut plan = proto_node.try_into_physical_plan(&task_ctx, &codec)?;
82
83            for hook in self.hooks.on_plan.iter() {
84                plan = hook(plan)
85            }
86
87            // Initialize partition count to the number of partitions in the stage
88            let total_partitions = plan.properties().partitioning.partition_count();
89            Ok::<_, DataFusionError>(TaskData {
90                plan,
91                task_ctx,
92                num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
93            })
94        };
95
96        entry.write(task_data().await.map_err(Arc::new)).map_err(|_| {
97            Status::internal(format!(
98                "Logic error while setting plan for TaskKey {key:?}: the plan was set twice. This is a bug in datafusion-distributed, please report it."
99            ))
100        })?;
101        Ok(())
102    }
103}
104
105fn missing(field: &'static str) -> impl FnOnce() -> Status {
106    move || Status::invalid_argument(format!("Missing field '{field}'"))
107}