Skip to main content

datafusion_dist/
event.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4    time::Duration,
5};
6
7use backon::{ExponentialBuilder, Retryable};
8use futures::future::join_all;
9use log::{debug, error, warn};
10use parking_lot::Mutex;
11use tokio::sync::mpsc::{Receiver, Sender};
12
13use crate::{
14    DistError, DistResult, JobId,
15    cluster::{DistCluster, NodeId},
16    config::DistConfig,
17    network::{DistNetwork, StageInfo},
18    planner::StageId,
19    runtime::{StageState, cleanup_stages},
20};
21
22#[derive(Debug, Clone)]
23pub enum Event {
24    CheckJobCompleted(JobId),
25    CleanupJob(JobId),
26    ReceivedStage0Tasks(Vec<StageId>),
27}
28
/// Maximum number of events pulled from the channel by one `recv_many` call.
const MAX_BATCH_SIZE: usize = 1 << 10;
/// How long `send_event_with_timeout` waits for channel capacity (5 minutes).
const EVENT_SEND_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Upper bound on the backoff delay between job-completion check attempts.
const CHECK_JOB_RETRY_MAX_DELAY: Duration = Duration::from_secs(10);
/// Maximum number of retries of a failed job-completion check.
const CHECK_JOB_RETRY_MAX_TIMES: usize = 3;
33
34fn job_check_retry_strategy() -> ExponentialBuilder {
35    ExponentialBuilder::default()
36        .with_max_delay(CHECK_JOB_RETRY_MAX_DELAY)
37        .with_max_times(CHECK_JOB_RETRY_MAX_TIMES)
38        .with_jitter()
39}
40
41pub async fn send_event_with_timeout(sender: &Sender<Event>, event: Event) -> DistResult<()> {
42    tokio::time::timeout(EVENT_SEND_TIMEOUT, sender.send(event))
43        .await
44        .map_err(|_| {
45            DistError::internal(format!(
46                "Timed out sending event after {}s",
47                EVENT_SEND_TIMEOUT.as_secs()
48            ))
49        })?
50        .map_err(|e| DistError::internal(format!("Failed to send event: {e}")))
51}
52
53/// Merge duplicate events in a batch.
54///
55/// - `CheckJobCompleted(JobId)` and `CleanupJob(JobId)` are deduplicated by
56///   `job_id`: only the first occurrence for a given job is kept.
57/// - All non-empty `ReceivedStage0Tasks(Vec<StageId>)` are concatenated into a
58///   single event, preserving batch order.
59/// - Empty `ReceivedStage0Tasks` vectors are silently skipped.
60fn merge_events(events: &mut Vec<Event>) -> Vec<Event> {
61    let mut merged: Vec<Event> = Vec::with_capacity(events.len());
62    let mut seen_check_jobs = HashSet::with_capacity(events.len());
63    let mut seen_cleanup_jobs = HashSet::with_capacity(events.len());
64    let mut stage0_ids = Vec::new();
65
66    for event in events.drain(..) {
67        match event {
68            Event::CheckJobCompleted(job_id) => {
69                if seen_check_jobs.insert(job_id.clone()) {
70                    merged.push(Event::CheckJobCompleted(job_id));
71                }
72            }
73            Event::CleanupJob(job_id) => {
74                if seen_cleanup_jobs.insert(job_id.clone()) {
75                    merged.push(Event::CleanupJob(job_id));
76                }
77            }
78            Event::ReceivedStage0Tasks(mut ids) => {
79                if !ids.is_empty() {
80                    stage0_ids.append(&mut ids);
81                }
82            }
83        }
84    }
85
86    if !stage0_ids.is_empty() {
87        merged.push(Event::ReceivedStage0Tasks(stage0_ids));
88    }
89
90    merged
91}
92
93pub fn start_event_handler(mut handler: EventHandler) {
94    tokio::spawn(async move {
95        handler.start().await;
96    });
97}
98
99pub struct EventHandler {
100    pub config: Arc<DistConfig>,
101    pub cluster: Arc<dyn DistCluster>,
102    pub network: Arc<dyn DistNetwork>,
103    pub local_stages: Arc<Mutex<HashMap<StageId, StageState>>>,
104    pub sender: Sender<Event>,
105    pub receiver: Receiver<Event>,
106}
107
108impl EventHandler {
109    pub async fn start(&mut self) {
110        let mut batch = Vec::with_capacity(MAX_BATCH_SIZE);
111        loop {
112            batch.clear();
113            let received = self.receiver.recv_many(&mut batch, MAX_BATCH_SIZE).await;
114            if received == 0 {
115                break;
116            }
117            debug!("Received batch of {received} events, merging duplicates");
118            let merged = merge_events(&mut batch);
119            debug!("Merged into {} events", merged.len());
120            self.handle_events(merged).await;
121        }
122    }
123
124    async fn handle_events(&self, events: Vec<Event>) {
125        let mut check_job_ids = Vec::new();
126        let mut cleanup_job_ids = Vec::new();
127        let mut all_stage0_ids = Vec::new();
128
129        for event in events {
130            debug!("Handling event: {event:?}");
131            match event {
132                Event::CheckJobCompleted(job_id) => check_job_ids.push(job_id),
133                Event::CleanupJob(job_id) => cleanup_job_ids.push(job_id),
134                Event::ReceivedStage0Tasks(stage0_ids) => all_stage0_ids.extend(stage0_ids),
135            }
136        }
137
138        if !check_job_ids.is_empty() {
139            let cluster = self.cluster.clone();
140            let network = self.network.clone();
141            let local_stages = self.local_stages.clone();
142            let sender = self.sender.clone();
143            tokio::spawn(async move {
144                handle_check_jobs_completed(
145                    &cluster,
146                    &network,
147                    &local_stages,
148                    &sender,
149                    check_job_ids.clone(),
150                )
151                .await;
152            });
153        }
154
155        if !cleanup_job_ids.is_empty() {
156            let cluster = self.cluster.clone();
157            let network = self.network.clone();
158            let local_stages = self.local_stages.clone();
159            tokio::spawn(async move {
160                if let Err(e) =
161                    cleanup_jobs(&cluster, &network, &local_stages, cleanup_job_ids.clone()).await
162                {
163                    error!("Failed to cleanup jobs {cleanup_job_ids:?}: {e}");
164                }
165            });
166        }
167
168        if !all_stage0_ids.is_empty() {
169            let local_stages = self.local_stages.clone();
170            let stage0_task_poll_timeout = self.config.stage0_task_poll_timeout;
171            let sender = self.sender.clone();
172            tokio::spawn(async move {
173                wait_stage0_tasks_polling(
174                    &local_stages,
175                    stage0_task_poll_timeout,
176                    &sender,
177                    all_stage0_ids,
178                )
179                .await
180            });
181        }
182    }
183}
184
185async fn handle_check_jobs_completed(
186    cluster: &Arc<dyn DistCluster>,
187    network: &Arc<dyn DistNetwork>,
188    local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
189    sender: &Sender<Event>,
190    job_ids: Vec<JobId>,
191) {
192    match (|| async { check_jobs_completed(cluster, network, local_stages, job_ids.clone()).await })
193        .retry(job_check_retry_strategy())
194        .await
195    {
196        Ok(completed_map) => {
197            for (job_id, completed) in completed_map {
198                if completed {
199                    debug!("Job {job_id} completed, remove it from cluster");
200                    if let Err(e) =
201                        send_event_with_timeout(sender, Event::CleanupJob(job_id.clone())).await
202                    {
203                        error!("Failed to send cleanup job event for job {job_id}: {e}");
204                    }
205                }
206            }
207        }
208        Err(err) => {
209            error!("Failed to check jobs {job_ids:?} completed: {err}");
210        }
211    }
212}
213
214pub async fn check_jobs_completed(
215    cluster: &Arc<dyn DistCluster>,
216    network: &Arc<dyn DistNetwork>,
217    local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
218    job_ids: Vec<JobId>,
219) -> DistResult<HashMap<JobId, bool>> {
220    if job_ids.is_empty() {
221        return Ok(HashMap::new());
222    }
223
224    // Get alive nodes for validation
225    let alive_nodes = cluster
226        .alive_nodes()
227        .await?
228        .keys()
229        .cloned()
230        .collect::<HashSet<_>>();
231
232    // Determine target nodes from job_task_distribution if available
233    let target_nodes_by_job = {
234        let guard = local_stages.lock();
235        job_ids
236            .iter()
237            .cloned()
238            .map(|job_id| {
239                let target_nodes = guard
240                    .values()
241                    .find(|stage| stage.stage_id.job_id == job_id)
242                    .map(|stage| {
243                        stage
244                            .job_task_distribution
245                            .values()
246                            .cloned()
247                            .collect::<HashSet<_>>()
248                    });
249                (job_id, target_nodes)
250            })
251            .collect::<Vec<_>>()
252    };
253
254    let mut completed_map = HashMap::with_capacity(job_ids.len());
255
256    let mut jobs_by_node: HashMap<NodeId, Vec<JobId>> = HashMap::new();
257    for (job_id, target_nodes) in target_nodes_by_job {
258        match target_nodes {
259            Some(nodes) if nodes.is_subset(&alive_nodes) => {
260                for node_id in nodes {
261                    jobs_by_node
262                        .entry(node_id)
263                        .or_default()
264                        .push(job_id.clone());
265                }
266            }
267            Some(nodes) => {
268                let missing: Vec<_> = nodes.difference(&alive_nodes).collect();
269                warn!(
270                    "Job {job_id} is polluted: task nodes {missing:?} are not alive, treat as completed"
271                );
272                completed_map.insert(job_id, true);
273            }
274            None => {
275                warn!(
276                    "No job_task_distribution found for job {job_id}, skipping remote status check"
277                );
278            }
279        }
280    }
281
282    let mut all_job_statuses = HashMap::new();
283
284    if let Some(local_job_ids) = jobs_by_node.remove(&network.local_node()) {
285        let local_job_statuses = local_jobs(local_stages, Some(&local_job_ids));
286        all_job_statuses.extend(local_job_statuses);
287    }
288
289    let mut futures = Vec::new();
290    for (node_id, job_ids) in jobs_by_node {
291        let network = network.clone();
292        futures.push(async move {
293            network
294                .get_jobs(node_id.clone(), Some(job_ids.clone()))
295                .await
296        });
297    }
298
299    for remote_status in join_all(futures).await {
300        let remote_status = remote_status?;
301        for (stage_id, remote_stage_info) in remote_status {
302            all_job_statuses
303                .entry(stage_id)
304                .and_modify(|existing| {
305                    existing.merge(&remote_stage_info);
306                })
307                .or_insert(remote_stage_info);
308        }
309    }
310
311    for job_id in job_ids {
312        if completed_map.contains_key(&job_id) {
313            continue;
314        }
315
316        let stage0 = StageId {
317            job_id: job_id.clone(),
318            stage: 0,
319        };
320
321        let job_completed = match all_job_statuses.get(&stage0) {
322            Some(stage0_info) => stage0_info.assigned_partitions.iter().all(|partition| {
323                stage0_info
324                    .task_set_infos
325                    .iter()
326                    .any(|ts| ts.dropped_partitions.contains_key(partition))
327            }),
328            None => true,
329        };
330        completed_map.insert(job_id, job_completed);
331    }
332
333    Ok(completed_map)
334}
335
336pub fn local_jobs(
337    stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
338    job_ids: Option<&Vec<JobId>>,
339) -> HashMap<StageId, StageInfo> {
340    let guard = stages.lock();
341
342    let mut result = HashMap::new();
343    for (stage_id, stage_state) in guard.iter() {
344        if job_ids.is_none_or(|job_ids| job_ids.contains(&stage_id.job_id)) {
345            let stage_info = StageInfo::from_stage_state(stage_state);
346            result.insert(stage_id.clone(), stage_info);
347        }
348    }
349
350    result
351}
352
353pub async fn cleanup_jobs(
354    cluster: &Arc<dyn DistCluster>,
355    network: &Arc<dyn DistNetwork>,
356    local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
357    job_ids: Vec<JobId>,
358) -> DistResult<()> {
359    let alive_nodes: HashSet<NodeId> = cluster.alive_nodes().await?.keys().cloned().collect();
360
361    let target_nodes_by_job = {
362        let guard = local_stages.lock();
363        job_ids
364            .iter()
365            .cloned()
366            .map(|job_id| {
367                let target_nodes = guard
368                    .values()
369                    .find(|stage| stage.stage_id.job_id == job_id)
370                    .map(|stage| {
371                        stage
372                            .job_task_distribution
373                            .values()
374                            .cloned()
375                            .collect::<HashSet<_>>()
376                    });
377                (job_id, target_nodes)
378            })
379            .collect::<Vec<_>>()
380    };
381
382    let mut jobs_by_node: HashMap<NodeId, Vec<JobId>> = HashMap::new();
383    for (job_id, target_nodes) in target_nodes_by_job {
384        let nodes_to_clean: HashSet<NodeId> = match target_nodes {
385            Some(nodes) if nodes.is_subset(&alive_nodes) => nodes,
386            Some(nodes) => {
387                let missing: Vec<_> = nodes.difference(&alive_nodes).collect();
388                warn!("Job {job_id} is polluted: task nodes {missing:?} are not alive");
389                nodes
390                    .into_iter()
391                    .filter(|n| alive_nodes.contains(n))
392                    .collect()
393            }
394            None => alive_nodes.clone(),
395        };
396
397        for node_id in nodes_to_clean {
398            jobs_by_node
399                .entry(node_id)
400                .or_default()
401                .push(job_id.clone());
402        }
403    }
404
405    if let Some(local_job_ids) = jobs_by_node.remove(&network.local_node()) {
406        let local_job_ids: HashSet<JobId> = local_job_ids.into_iter().collect();
407        cleanup_stages(&mut local_stages.lock(), |stage_id| {
408            local_job_ids.contains(&stage_id.job_id)
409        });
410    }
411
412    let mut futures = Vec::new();
413    for (node_id, job_ids) in jobs_by_node {
414        if !job_ids.is_empty() {
415            let network = network.clone();
416            futures
417                .push(async move { network.cleanup_jobs(node_id.clone(), job_ids.clone()).await });
418        }
419    }
420
421    for res in join_all(futures).await {
422        res?;
423    }
424    Ok(())
425}
426
427async fn wait_stage0_tasks_polling(
428    local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
429    stage0_task_poll_timeout: Duration,
430    sender: &Sender<Event>,
431    stage0_ids: Vec<StageId>,
432) {
433    tokio::time::sleep(stage0_task_poll_timeout).await;
434
435    let mut timeout_job_ids = HashSet::new();
436    {
437        let stages_guard = local_stages.lock();
438        for stage_id in stage0_ids {
439            if let Some(stage) = stages_guard.get(&stage_id)
440                && stage.never_executed()
441            {
442                debug!("Found stage0 {stage_id} never polled until timeout");
443                timeout_job_ids.insert(stage_id.job_id.clone());
444            }
445        }
446        drop(stages_guard);
447    }
448
449    for job_id in timeout_job_ids {
450        if let Err(e) = send_event_with_timeout(sender, Event::CleanupJob(job_id.clone())).await {
451            error!("Failed to send CleanupJob event for job {job_id}: {e}");
452        }
453    }
454}