1use std::{collections::HashMap, sync::Arc};
2
3use log::{debug, error};
4use parking_lot::Mutex;
5use tokio::sync::mpsc::{Receiver, Sender};
6use uuid::Uuid;
7
8use crate::{
9 DistResult,
10 cluster::{DistCluster, NodeId},
11 config::DistConfig,
12 network::{DistNetwork, StageInfo},
13 planner::StageId,
14 runtime::StageState,
15};
16
/// Control-plane events consumed by [`EventHandler`].
#[derive(Debug, Clone)]
pub enum Event {
    /// Check whether the job with this id has completed cluster-wide;
    /// if so, a `CleanupJob` event is queued for it.
    CheckJobCompleted(Uuid),
    /// Remove all state belonging to this job, locally and on remote nodes.
    CleanupJob(Uuid),
    /// Stage-0 task sets were received; arms a poll-timeout watchdog that
    /// cleans up any job whose stage 0 is never executed in time.
    ReceivedStage0Tasks(Vec<StageId>),
}
23
24pub fn start_event_handler(mut handler: EventHandler) {
25 tokio::spawn(async move {
26 handler.start().await;
27 });
28}
29
/// Owns the receiving end of the event channel and dispatches each incoming
/// [`Event`] to the matching handler method.
pub struct EventHandler {
    /// Identity of the node this handler runs on.
    pub local_node: NodeId,
    /// Shared configuration; read for `stage0_task_poll_timeout`.
    pub config: Arc<DistConfig>,
    /// Cluster membership view (used for `alive_nodes`).
    pub cluster: Arc<dyn DistCluster>,
    /// RPC surface to other nodes (`get_job_status`, `cleanup_job`, ...).
    pub network: Arc<dyn DistNetwork>,
    /// Stage state tracked on this node, keyed by stage id.
    pub local_stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    /// Used to enqueue follow-up events (presumably the sender half of the
    /// same channel `receiver` drains — confirm at the construction site).
    pub sender: Sender<Event>,
    /// Receiving end drained by [`EventHandler::start`].
    pub receiver: Receiver<Event>,
}
39
40impl EventHandler {
41 pub async fn start(&mut self) {
42 while let Some(event) = self.receiver.recv().await {
43 debug!("Received event: {event:?}");
44 match event {
45 Event::CheckJobCompleted(job_id) => {
46 self.handle_check_job_completed(job_id).await;
47 }
48 Event::CleanupJob(job_id) => {
49 self.handle_cleanup_job(job_id).await;
50 }
51 Event::ReceivedStage0Tasks(stage0_ids) => {
52 self.handle_received_stage0_tasks(stage0_ids).await;
53 }
54 }
55 }
56 }
57
58 async fn handle_check_job_completed(&mut self, job_id: Uuid) {
59 match check_job_completed(&self.cluster, &self.network, &self.local_stages, job_id).await {
60 Ok(Some(true)) => {
61 debug!("Job {job_id} completed, remove it from cluster");
62
63 if let Err(e) = self.sender.send(Event::CleanupJob(job_id)).await {
64 error!("Failed to send cleanup job event for job {job_id}: {e}");
65 }
66 }
67 Ok(_) => {}
68 Err(err) => {
69 error!("Failed to check job {job_id} completed: {err}");
70 }
71 }
72 }
73
74 async fn handle_cleanup_job(&mut self, job_id: Uuid) {
75 if let Err(e) = cleanup_job(
76 &self.local_node,
77 &self.cluster,
78 &self.network,
79 &self.local_stages,
80 job_id,
81 )
82 .await
83 {
84 error!("Failed to cleanup job {job_id}: {e}");
85 }
86 }
87
88 async fn handle_received_stage0_tasks(&self, stage0_ids: Vec<StageId>) {
89 let stage0_task_poll_timeout = self.config.stage0_task_poll_timeout;
90 let local_stages = self.local_stages.clone();
91 let sender = self.sender.clone();
92 tokio::spawn(async move {
93 tokio::time::sleep(stage0_task_poll_timeout).await;
94
95 let mut timeout_stage0_id = None;
96 {
97 let stages_guard = local_stages.lock();
98 for stage_id in stage0_ids {
99 if let Some(stage) = stages_guard.get(&stage_id)
100 && stage.never_executed()
101 {
102 debug!("Found stage0 {stage_id} never polled until timeout");
103 timeout_stage0_id = Some(stage_id);
104 break;
105 }
106 }
107 drop(stages_guard);
108 }
109
110 if let Some(stage_id) = timeout_stage0_id
111 && let Err(e) = sender.send(Event::CleanupJob(stage_id.job_id)).await
112 {
113 error!(
114 "Failed to send CleanupJob event for job {}: {e}",
115 stage_id.job_id
116 );
117 }
118 });
119 }
120}
121
122pub async fn check_job_completed(
123 cluster: &Arc<dyn DistCluster>,
124 network: &Arc<dyn DistNetwork>,
125 local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
126 job_id: Uuid,
127) -> DistResult<Option<bool>> {
128 let mut combined_status = local_stage_stats(local_stages, Some(job_id));
130
131 let node_states = cluster.alive_nodes().await?;
133
134 let local_node_id = network.local_node();
135
136 let mut handles = Vec::new();
137 for node_id in node_states.keys() {
138 if *node_id != local_node_id {
139 let network = network.clone();
140 let node_id = node_id.clone();
141 let handle =
142 tokio::spawn(async move { network.get_job_status(node_id, Some(job_id)).await });
143 handles.push(handle);
144 }
145 }
146
147 for handle in handles {
148 let remote_status = handle.await??;
149 for (stage_id, remote_stage_info) in remote_status {
150 combined_status
151 .entry(stage_id)
152 .and_modify(|existing| {
153 existing
154 .assigned_partitions
155 .extend(&remote_stage_info.assigned_partitions);
156 existing
157 .task_set_infos
158 .extend(remote_stage_info.task_set_infos.clone());
159 })
160 .or_insert(remote_stage_info);
161 }
162 }
163
164 let stage0 = StageId { job_id, stage: 0 };
165
166 let Some(stage0_info) = combined_status.get(&stage0) else {
167 return Ok(None);
168 };
169
170 for partition in &stage0_info.assigned_partitions {
172 let is_completed = stage0_info
173 .task_set_infos
174 .iter()
175 .any(|ts| ts.dropped_partitions.contains_key(partition));
176 if !is_completed {
177 return Ok(Some(false));
178 }
179 }
180
181 Ok(Some(true))
182}
183
184pub fn local_stage_stats(
185 stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
186 job_id: Option<Uuid>,
187) -> HashMap<StageId, StageInfo> {
188 let guard = stages.lock();
189
190 let mut result = HashMap::new();
191 for (stage_id, stage_state) in guard.iter() {
192 if job_id.is_none() || stage_id.job_id == job_id.unwrap() {
193 let stage_info = StageInfo::from_stage_state(stage_state);
194 result.insert(*stage_id, stage_info);
195 }
196 }
197
198 result
199}
200
201pub async fn cleanup_job(
202 local_node: &NodeId,
203 cluster: &Arc<dyn DistCluster>,
204 network: &Arc<dyn DistNetwork>,
205 local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
206 job_id: Uuid,
207) -> DistResult<()> {
208 let alive_nodes = cluster.alive_nodes().await?;
209
210 for node_id in alive_nodes.keys() {
211 if node_id == local_node {
212 let mut guard = local_stages.lock();
213 guard.retain(|stage_id, _| stage_id.job_id != job_id);
214 drop(guard);
215 } else {
216 network.cleanup_job(node_id.clone(), job_id).await?
218 }
219 }
220 Ok(())
221}