1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4 time::Duration,
5};
6
7use backon::{ExponentialBuilder, Retryable};
8use futures::future::join_all;
9use log::{debug, error, warn};
10use parking_lot::Mutex;
11use tokio::sync::mpsc::{Receiver, Sender};
12
13use crate::{
14 DistError, DistResult, JobId,
15 cluster::{DistCluster, NodeId},
16 config::DistConfig,
17 network::{DistNetwork, StageInfo},
18 planner::StageId,
19 runtime::{StageState, cleanup_stages},
20};
21
/// Work items consumed by [`EventHandler`].
#[derive(Debug, Clone)]
pub enum Event {
    /// Check whether the job has completed; if it has, a `CleanupJob` event is
    /// enqueued for it.
    CheckJobCompleted(JobId),
    /// Drop the job's stage state locally and on the nodes of its task
    /// distribution (or on every alive node when that distribution is unknown).
    CleanupJob(JobId),
    /// Stage-0 stages that just received tasks. After
    /// `stage0_task_poll_timeout`, any of them that never executed cause their
    /// job to be cleaned up.
    ReceivedStage0Tasks(Vec<StageId>),
}
28
/// Upper bound on the number of events pulled from the channel per loop iteration.
const MAX_BATCH_SIZE: usize = 1 << 10;
/// How long `send_event_with_timeout` waits for channel capacity (five minutes).
const EVENT_SEND_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Cap on the back-off delay between job-completion check attempts.
const CHECK_JOB_RETRY_MAX_DELAY: Duration = Duration::from_secs(10);
/// Retry limit handed to the back-off builder (`with_max_times`).
const CHECK_JOB_RETRY_MAX_TIMES: usize = 3;
33
34fn job_check_retry_strategy() -> ExponentialBuilder {
35 ExponentialBuilder::default()
36 .with_max_delay(CHECK_JOB_RETRY_MAX_DELAY)
37 .with_max_times(CHECK_JOB_RETRY_MAX_TIMES)
38 .with_jitter()
39}
40
41pub async fn send_event_with_timeout(sender: &Sender<Event>, event: Event) -> DistResult<()> {
42 tokio::time::timeout(EVENT_SEND_TIMEOUT, sender.send(event))
43 .await
44 .map_err(|_| {
45 DistError::internal(format!(
46 "Timed out sending event after {}s",
47 EVENT_SEND_TIMEOUT.as_secs()
48 ))
49 })?
50 .map_err(|e| DistError::internal(format!("Failed to send event: {e}")))
51}
52
53fn merge_events(events: &mut Vec<Event>) -> Vec<Event> {
61 let mut merged: Vec<Event> = Vec::with_capacity(events.len());
62 let mut seen_check_jobs = HashSet::with_capacity(events.len());
63 let mut seen_cleanup_jobs = HashSet::with_capacity(events.len());
64 let mut stage0_ids = Vec::new();
65
66 for event in events.drain(..) {
67 match event {
68 Event::CheckJobCompleted(job_id) => {
69 if seen_check_jobs.insert(job_id.clone()) {
70 merged.push(Event::CheckJobCompleted(job_id));
71 }
72 }
73 Event::CleanupJob(job_id) => {
74 if seen_cleanup_jobs.insert(job_id.clone()) {
75 merged.push(Event::CleanupJob(job_id));
76 }
77 }
78 Event::ReceivedStage0Tasks(mut ids) => {
79 if !ids.is_empty() {
80 stage0_ids.append(&mut ids);
81 }
82 }
83 }
84 }
85
86 if !stage0_ids.is_empty() {
87 merged.push(Event::ReceivedStage0Tasks(stage0_ids));
88 }
89
90 merged
91}
92
93pub fn start_event_handler(mut handler: EventHandler) {
94 tokio::spawn(async move {
95 handler.start().await;
96 });
97}
98
/// Receives [`Event`]s in batches and dispatches the resulting work onto
/// spawned Tokio tasks.
pub struct EventHandler {
    /// Runtime configuration; `stage0_task_poll_timeout` is read from it.
    pub config: Arc<DistConfig>,
    /// Cluster membership, queried for the set of alive nodes.
    pub cluster: Arc<dyn DistCluster>,
    /// Network layer: local node identity plus remote job-status and cleanup calls.
    pub network: Arc<dyn DistNetwork>,
    /// Stage state tracked on this node, behind a `parking_lot` mutex.
    pub local_stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    /// Cloned into spawned tasks so they can enqueue follow-up events (e.g. `CleanupJob`).
    pub sender: Sender<Event>,
    /// Source of the events drained in batches by `start`.
    pub receiver: Receiver<Event>,
}
107
108impl EventHandler {
109 pub async fn start(&mut self) {
110 let mut batch = Vec::with_capacity(MAX_BATCH_SIZE);
111 loop {
112 batch.clear();
113 let received = self.receiver.recv_many(&mut batch, MAX_BATCH_SIZE).await;
114 if received == 0 {
115 break;
116 }
117 debug!("Received batch of {received} events, merging duplicates");
118 let merged = merge_events(&mut batch);
119 debug!("Merged into {} events", merged.len());
120 self.handle_events(merged).await;
121 }
122 }
123
124 async fn handle_events(&self, events: Vec<Event>) {
125 let mut check_job_ids = Vec::new();
126 let mut cleanup_job_ids = Vec::new();
127 let mut all_stage0_ids = Vec::new();
128
129 for event in events {
130 debug!("Handling event: {event:?}");
131 match event {
132 Event::CheckJobCompleted(job_id) => check_job_ids.push(job_id),
133 Event::CleanupJob(job_id) => cleanup_job_ids.push(job_id),
134 Event::ReceivedStage0Tasks(stage0_ids) => all_stage0_ids.extend(stage0_ids),
135 }
136 }
137
138 if !check_job_ids.is_empty() {
139 let cluster = self.cluster.clone();
140 let network = self.network.clone();
141 let local_stages = self.local_stages.clone();
142 let sender = self.sender.clone();
143 tokio::spawn(async move {
144 handle_check_jobs_completed(
145 &cluster,
146 &network,
147 &local_stages,
148 &sender,
149 check_job_ids.clone(),
150 )
151 .await;
152 });
153 }
154
155 if !cleanup_job_ids.is_empty() {
156 let cluster = self.cluster.clone();
157 let network = self.network.clone();
158 let local_stages = self.local_stages.clone();
159 tokio::spawn(async move {
160 if let Err(e) =
161 cleanup_jobs(&cluster, &network, &local_stages, cleanup_job_ids.clone()).await
162 {
163 error!("Failed to cleanup jobs {cleanup_job_ids:?}: {e}");
164 }
165 });
166 }
167
168 if !all_stage0_ids.is_empty() {
169 let local_stages = self.local_stages.clone();
170 let stage0_task_poll_timeout = self.config.stage0_task_poll_timeout;
171 let sender = self.sender.clone();
172 tokio::spawn(async move {
173 wait_stage0_tasks_polling(
174 &local_stages,
175 stage0_task_poll_timeout,
176 &sender,
177 all_stage0_ids,
178 )
179 .await
180 });
181 }
182 }
183}
184
185async fn handle_check_jobs_completed(
186 cluster: &Arc<dyn DistCluster>,
187 network: &Arc<dyn DistNetwork>,
188 local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
189 sender: &Sender<Event>,
190 job_ids: Vec<JobId>,
191) {
192 match (|| async { check_jobs_completed(cluster, network, local_stages, job_ids.clone()).await })
193 .retry(job_check_retry_strategy())
194 .await
195 {
196 Ok(completed_map) => {
197 for (job_id, completed) in completed_map {
198 if completed {
199 debug!("Job {job_id} completed, remove it from cluster");
200 if let Err(e) =
201 send_event_with_timeout(sender, Event::CleanupJob(job_id.clone())).await
202 {
203 error!("Failed to send cleanup job event for job {job_id}: {e}");
204 }
205 }
206 }
207 }
208 Err(err) => {
209 error!("Failed to check jobs {job_ids:?} completed: {err}");
210 }
211 }
212}
213
214pub async fn check_jobs_completed(
215 cluster: &Arc<dyn DistCluster>,
216 network: &Arc<dyn DistNetwork>,
217 local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
218 job_ids: Vec<JobId>,
219) -> DistResult<HashMap<JobId, bool>> {
220 if job_ids.is_empty() {
221 return Ok(HashMap::new());
222 }
223
224 let alive_nodes = cluster
226 .alive_nodes()
227 .await?
228 .keys()
229 .cloned()
230 .collect::<HashSet<_>>();
231
232 let target_nodes_by_job = {
234 let guard = local_stages.lock();
235 job_ids
236 .iter()
237 .cloned()
238 .map(|job_id| {
239 let target_nodes = guard
240 .values()
241 .find(|stage| stage.stage_id.job_id == job_id)
242 .map(|stage| {
243 stage
244 .job_task_distribution
245 .values()
246 .cloned()
247 .collect::<HashSet<_>>()
248 });
249 (job_id, target_nodes)
250 })
251 .collect::<Vec<_>>()
252 };
253
254 let mut completed_map = HashMap::with_capacity(job_ids.len());
255
256 let mut jobs_by_node: HashMap<NodeId, Vec<JobId>> = HashMap::new();
257 for (job_id, target_nodes) in target_nodes_by_job {
258 match target_nodes {
259 Some(nodes) if nodes.is_subset(&alive_nodes) => {
260 for node_id in nodes {
261 jobs_by_node
262 .entry(node_id)
263 .or_default()
264 .push(job_id.clone());
265 }
266 }
267 Some(nodes) => {
268 let missing: Vec<_> = nodes.difference(&alive_nodes).collect();
269 warn!(
270 "Job {job_id} is polluted: task nodes {missing:?} are not alive, treat as completed"
271 );
272 completed_map.insert(job_id, true);
273 }
274 None => {
275 warn!(
276 "No job_task_distribution found for job {job_id}, skipping remote status check"
277 );
278 }
279 }
280 }
281
282 let mut all_job_statuses = HashMap::new();
283
284 if let Some(local_job_ids) = jobs_by_node.remove(&network.local_node()) {
285 let local_job_statuses = local_jobs(local_stages, Some(&local_job_ids));
286 all_job_statuses.extend(local_job_statuses);
287 }
288
289 let mut futures = Vec::new();
290 for (node_id, job_ids) in jobs_by_node {
291 let network = network.clone();
292 futures.push(async move {
293 network
294 .get_jobs(node_id.clone(), Some(job_ids.clone()))
295 .await
296 });
297 }
298
299 for remote_status in join_all(futures).await {
300 let remote_status = remote_status?;
301 for (stage_id, remote_stage_info) in remote_status {
302 all_job_statuses
303 .entry(stage_id)
304 .and_modify(|existing| {
305 existing.merge(&remote_stage_info);
306 })
307 .or_insert(remote_stage_info);
308 }
309 }
310
311 for job_id in job_ids {
312 if completed_map.contains_key(&job_id) {
313 continue;
314 }
315
316 let stage0 = StageId {
317 job_id: job_id.clone(),
318 stage: 0,
319 };
320
321 let job_completed = match all_job_statuses.get(&stage0) {
322 Some(stage0_info) => stage0_info.assigned_partitions.iter().all(|partition| {
323 stage0_info
324 .task_set_infos
325 .iter()
326 .any(|ts| ts.dropped_partitions.contains_key(partition))
327 }),
328 None => true,
329 };
330 completed_map.insert(job_id, job_completed);
331 }
332
333 Ok(completed_map)
334}
335
336pub fn local_jobs(
337 stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
338 job_ids: Option<&Vec<JobId>>,
339) -> HashMap<StageId, StageInfo> {
340 let guard = stages.lock();
341
342 let mut result = HashMap::new();
343 for (stage_id, stage_state) in guard.iter() {
344 if job_ids.is_none_or(|job_ids| job_ids.contains(&stage_id.job_id)) {
345 let stage_info = StageInfo::from_stage_state(stage_state);
346 result.insert(stage_id.clone(), stage_info);
347 }
348 }
349
350 result
351}
352
353pub async fn cleanup_jobs(
354 cluster: &Arc<dyn DistCluster>,
355 network: &Arc<dyn DistNetwork>,
356 local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
357 job_ids: Vec<JobId>,
358) -> DistResult<()> {
359 let alive_nodes: HashSet<NodeId> = cluster.alive_nodes().await?.keys().cloned().collect();
360
361 let target_nodes_by_job = {
362 let guard = local_stages.lock();
363 job_ids
364 .iter()
365 .cloned()
366 .map(|job_id| {
367 let target_nodes = guard
368 .values()
369 .find(|stage| stage.stage_id.job_id == job_id)
370 .map(|stage| {
371 stage
372 .job_task_distribution
373 .values()
374 .cloned()
375 .collect::<HashSet<_>>()
376 });
377 (job_id, target_nodes)
378 })
379 .collect::<Vec<_>>()
380 };
381
382 let mut jobs_by_node: HashMap<NodeId, Vec<JobId>> = HashMap::new();
383 for (job_id, target_nodes) in target_nodes_by_job {
384 let nodes_to_clean: HashSet<NodeId> = match target_nodes {
385 Some(nodes) if nodes.is_subset(&alive_nodes) => nodes,
386 Some(nodes) => {
387 let missing: Vec<_> = nodes.difference(&alive_nodes).collect();
388 warn!("Job {job_id} is polluted: task nodes {missing:?} are not alive");
389 nodes
390 .into_iter()
391 .filter(|n| alive_nodes.contains(n))
392 .collect()
393 }
394 None => alive_nodes.clone(),
395 };
396
397 for node_id in nodes_to_clean {
398 jobs_by_node
399 .entry(node_id)
400 .or_default()
401 .push(job_id.clone());
402 }
403 }
404
405 if let Some(local_job_ids) = jobs_by_node.remove(&network.local_node()) {
406 let local_job_ids: HashSet<JobId> = local_job_ids.into_iter().collect();
407 cleanup_stages(&mut local_stages.lock(), |stage_id| {
408 local_job_ids.contains(&stage_id.job_id)
409 });
410 }
411
412 let mut futures = Vec::new();
413 for (node_id, job_ids) in jobs_by_node {
414 if !job_ids.is_empty() {
415 let network = network.clone();
416 futures
417 .push(async move { network.cleanup_jobs(node_id.clone(), job_ids.clone()).await });
418 }
419 }
420
421 for res in join_all(futures).await {
422 res?;
423 }
424 Ok(())
425}
426
427async fn wait_stage0_tasks_polling(
428 local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
429 stage0_task_poll_timeout: Duration,
430 sender: &Sender<Event>,
431 stage0_ids: Vec<StageId>,
432) {
433 tokio::time::sleep(stage0_task_poll_timeout).await;
434
435 let mut timeout_job_ids = HashSet::new();
436 {
437 let stages_guard = local_stages.lock();
438 for stage_id in stage0_ids {
439 if let Some(stage) = stages_guard.get(&stage_id)
440 && stage.never_executed()
441 {
442 debug!("Found stage0 {stage_id} never polled until timeout");
443 timeout_job_ids.insert(stage_id.job_id.clone());
444 }
445 }
446 drop(stages_guard);
447 }
448
449 for job_id in timeout_job_ids {
450 if let Err(e) = send_event_with_timeout(sender, Event::CleanupJob(job_id.clone())).await {
451 error!("Failed to send CleanupJob event for job {job_id}: {e}");
452 }
453 }
454}