Skip to main content

datafusion_dist/
network.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::Debug,
4    sync::Arc,
5};
6
7use datafusion_execution::SendableRecordBatchStream;
8use datafusion_physical_plan::ExecutionPlan;
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    DistError, DistResult, JobId,
13    cluster::NodeId,
14    planner::{StageId, TaskId},
15    runtime::{StageState, TaskMetrics, TaskSet},
16};
17
18#[async_trait::async_trait]
19pub trait DistNetwork: Debug + Send + Sync {
20    fn local_node(&self) -> NodeId;
21
22    async fn send_tasks(&self, node_id: NodeId, scheduled_tasks: ScheduledTasks) -> DistResult<()>;
23
24    async fn execute_task(
25        &self,
26        node_id: NodeId,
27        task_id: TaskId,
28    ) -> DistResult<SendableRecordBatchStream>;
29
30    async fn get_jobs(
31        &self,
32        node_id: NodeId,
33        job_ids: Option<Vec<JobId>>,
34    ) -> DistResult<HashMap<StageId, StageInfo>>;
35
36    async fn cleanup_jobs(&self, node_id: NodeId, job_ids: Vec<JobId>) -> DistResult<()>;
37}
38
39pub struct ScheduledTasks {
40    pub stage_plans: HashMap<StageId, Arc<dyn ExecutionPlan>>,
41    pub task_ids: Vec<TaskId>,
42    pub job_task_distribution: Arc<HashMap<TaskId, NodeId>>,
43    pub job_meta: Arc<HashMap<String, String>>,
44}
45
46impl ScheduledTasks {
47    pub fn new(
48        stage_plans: HashMap<StageId, Arc<dyn ExecutionPlan>>,
49        task_ids: Vec<TaskId>,
50        job_task_distribution: Arc<HashMap<TaskId, NodeId>>,
51        job_meta: Arc<HashMap<String, String>>,
52    ) -> Self {
53        ScheduledTasks {
54            stage_plans,
55            task_ids,
56            job_task_distribution,
57            job_meta,
58        }
59    }
60
61    pub fn job_id(&self) -> DistResult<JobId> {
62        self.task_ids
63            .first()
64            .map(|task_id| task_id.job_id.clone())
65            .ok_or_else(|| DistError::internal("ScheduledTasks has no task_ids".to_string()))
66    }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct StageInfo {
71    pub created_at_ms: i64,
72    pub assigned_partitions: HashSet<usize>,
73    pub task_set_infos: Vec<TaskSetInfo>,
74    pub job_meta: Arc<HashMap<String, String>>,
75}
76
77impl StageInfo {
78    pub fn from_stage_state(stage_state: &StageState) -> Self {
79        let task_set_infos = stage_state
80            .task_sets
81            .iter()
82            .map(TaskSetInfo::from_task_set)
83            .collect();
84
85        StageInfo {
86            created_at_ms: stage_state.created_at_ms,
87            assigned_partitions: stage_state.assigned_partitions.clone(),
88            task_set_infos,
89            job_meta: stage_state.job_meta.clone(),
90        }
91    }
92
93    pub fn merge(&mut self, other: &StageInfo) {
94        self.assigned_partitions.extend(&other.assigned_partitions);
95        self.task_set_infos.extend(other.task_set_infos.clone());
96    }
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct TaskSetInfo {
101    pub running_partitions: HashSet<usize>,
102    pub dropped_partitions: HashMap<usize, TaskMetrics>,
103}
104
105impl TaskSetInfo {
106    pub fn from_task_set(task_set: &TaskSet) -> Self {
107        TaskSetInfo {
108            running_partitions: task_set.running_partitions.keys().copied().collect(),
109            dropped_partitions: task_set.dropped_partitions.clone(),
110        }
111    }
112}