datafusion_dist/
event.rs

1use std::{collections::HashMap, sync::Arc};
2
3use log::{debug, error};
4use parking_lot::Mutex;
5use tokio::sync::mpsc::{Receiver, Sender};
6use uuid::Uuid;
7
8use crate::{
9    DistResult,
10    cluster::{DistCluster, NodeId},
11    config::DistConfig,
12    network::{DistNetwork, StageInfo},
13    planner::StageId,
14    runtime::StageState,
15};
16
/// Control-plane events processed serially by [`EventHandler`].
#[derive(Debug, Clone)]
pub enum Event {
    /// Check whether the job with this id has completed on every alive node;
    /// if so, a `CleanupJob` follow-up event is enqueued.
    CheckJobCompleted(Uuid),
    /// Remove all state for the job with this id, both locally and on every
    /// other alive node.
    CleanupJob(Uuid),
    /// Stage-0 task sets were received; starts a watchdog that cleans up the
    /// job if a stage 0 is never polled within the configured timeout.
    ReceivedStage0Tasks(Vec<StageId>),
}
23
24pub fn start_event_handler(mut handler: EventHandler) {
25    tokio::spawn(async move {
26        handler.start().await;
27    });
28}
29
/// Drains [`Event`]s from an mpsc channel one at a time and coordinates job
/// bookkeeping (completion checks, cleanup) across the cluster.
pub struct EventHandler {
    /// Identity of the node this handler runs on.
    pub local_node: NodeId,
    /// Shared distributed-execution configuration (e.g. the stage-0 task poll
    /// timeout used by the watchdog).
    pub config: Arc<DistConfig>,
    /// Cluster membership view, used to enumerate alive nodes.
    pub cluster: Arc<dyn DistCluster>,
    /// Network layer for status queries and cleanup requests to remote nodes.
    pub network: Arc<dyn DistNetwork>,
    /// Stages hosted on this node, keyed by stage id; shared with the rest of
    /// the runtime via a (synchronous) parking_lot mutex.
    pub local_stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    /// Sender side of the event channel, used to enqueue follow-up events
    /// (e.g. `CleanupJob` after a completed check).
    pub sender: Sender<Event>,
    /// Receiver side of the event channel the loop drains.
    pub receiver: Receiver<Event>,
}
39
impl EventHandler {
    /// Run the event loop: receive events until every `Sender` clone is
    /// dropped and the channel closes. Events are handled one at a time, in
    /// arrival order — handlers never run concurrently with each other.
    pub async fn start(&mut self) {
        while let Some(event) = self.receiver.recv().await {
            debug!("Received event: {event:?}");
            match event {
                Event::CheckJobCompleted(job_id) => {
                    self.handle_check_job_completed(job_id).await;
                }
                Event::CleanupJob(job_id) => {
                    self.handle_cleanup_job(job_id).await;
                }
                Event::ReceivedStage0Tasks(stage0_ids) => {
                    self.handle_received_stage0_tasks(stage0_ids).await;
                }
            }
        }
    }

    /// If `job_id` is confirmed complete across all alive nodes, enqueue a
    /// `CleanupJob` event for it. Failures are logged, never propagated.
    async fn handle_check_job_completed(&mut self, job_id: Uuid) {
        match check_job_completed(&self.cluster, &self.network, &self.local_stages, job_id).await {
            Ok(Some(true)) => {
                debug!("Job {job_id} completed, remove it from cluster");

                if let Err(e) = self.sender.send(Event::CleanupJob(job_id)).await {
                    error!("Failed to send cleanup job event for job {job_id}: {e}");
                }
            }
            // Ok(Some(false)): not finished yet; Ok(None): no node knows the
            // job's stage 0. Both are benign — do nothing.
            Ok(_) => {}
            Err(err) => {
                error!("Failed to check job {job_id} completed: {err}");
            }
        }
    }

    /// Best-effort removal of all of `job_id`'s state from the cluster;
    /// failures are logged, never propagated.
    async fn handle_cleanup_job(&mut self, job_id: Uuid) {
        if let Err(e) = cleanup_job(
            &self.local_node,
            &self.cluster,
            &self.network,
            &self.local_stages,
            job_id,
        )
        .await
        {
            error!("Failed to cleanup job {job_id}: {e}");
        }
    }

    /// Start a detached watchdog for newly received stage-0 tasks: after the
    /// configured poll timeout, any listed stage 0 that was never executed
    /// causes its job to be cleaned up (the consumer presumably went away
    /// before ever polling — TODO confirm that interpretation).
    async fn handle_received_stage0_tasks(&self, stage0_ids: Vec<StageId>) {
        let stage0_task_poll_timeout = self.config.stage0_task_poll_timeout;
        let local_stages = self.local_stages.clone();
        let sender = self.sender.clone();
        tokio::spawn(async move {
            tokio::time::sleep(stage0_task_poll_timeout).await;

            let mut timeout_stage0_id = None;
            // The parking_lot guard is held only across this synchronous scan;
            // no awaits occur while it is held.
            {
                let stages_guard = local_stages.lock();
                for stage_id in stage0_ids {
                    if let Some(stage) = stages_guard.get(&stage_id)
                        && stage.never_executed()
                    {
                        debug!("Found stage0 {stage_id} never polled until timeout");
                        timeout_stage0_id = Some(stage_id);
                        break;
                    }
                }
                // NOTE(review): redundant — the guard is dropped at the end of
                // this scope anyway.
                drop(stages_guard);
            }

            // NOTE(review): only the FIRST timed-out stage's job is cleaned
            // up. If `stage0_ids` can span multiple jobs, later timed-out jobs
            // are skipped — confirm this is intended.
            if let Some(stage_id) = timeout_stage0_id
                && let Err(e) = sender.send(Event::CleanupJob(stage_id.job_id)).await
            {
                error!(
                    "Failed to send CleanupJob event for job {}: {e}",
                    stage_id.job_id
                );
            }
        });
    }
}
121
122pub async fn check_job_completed(
123    cluster: &Arc<dyn DistCluster>,
124    network: &Arc<dyn DistNetwork>,
125    local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
126    job_id: Uuid,
127) -> DistResult<Option<bool>> {
128    // First, get local status
129    let mut combined_status = local_stage_stats(local_stages, Some(job_id));
130
131    // Then, get status from all other alive nodes
132    let node_states = cluster.alive_nodes().await?;
133
134    let local_node_id = network.local_node();
135
136    let mut handles = Vec::new();
137    for node_id in node_states.keys() {
138        if *node_id != local_node_id {
139            let network = network.clone();
140            let node_id = node_id.clone();
141            let handle =
142                tokio::spawn(async move { network.get_job_status(node_id, Some(job_id)).await });
143            handles.push(handle);
144        }
145    }
146
147    for handle in handles {
148        let remote_status = handle.await??;
149        for (stage_id, remote_stage_info) in remote_status {
150            combined_status
151                .entry(stage_id)
152                .and_modify(|existing| {
153                    existing
154                        .assigned_partitions
155                        .extend(&remote_stage_info.assigned_partitions);
156                    existing
157                        .task_set_infos
158                        .extend(remote_stage_info.task_set_infos.clone());
159                })
160                .or_insert(remote_stage_info);
161        }
162    }
163
164    let stage0 = StageId { job_id, stage: 0 };
165
166    let Some(stage0_info) = combined_status.get(&stage0) else {
167        return Ok(None);
168    };
169
170    // Check if all assigned partitions are completed
171    for partition in &stage0_info.assigned_partitions {
172        let is_completed = stage0_info
173            .task_set_infos
174            .iter()
175            .any(|ts| ts.dropped_partitions.contains_key(partition));
176        if !is_completed {
177            return Ok(Some(false));
178        }
179    }
180
181    Ok(Some(true))
182}
183
184pub fn local_stage_stats(
185    stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
186    job_id: Option<Uuid>,
187) -> HashMap<StageId, StageInfo> {
188    let guard = stages.lock();
189
190    let mut result = HashMap::new();
191    for (stage_id, stage_state) in guard.iter() {
192        if job_id.is_none() || stage_id.job_id == job_id.unwrap() {
193            let stage_info = StageInfo::from_stage_state(stage_state);
194            result.insert(*stage_id, stage_info);
195        }
196    }
197
198    result
199}
200
201pub async fn cleanup_job(
202    local_node: &NodeId,
203    cluster: &Arc<dyn DistCluster>,
204    network: &Arc<dyn DistNetwork>,
205    local_stages: &Arc<Mutex<HashMap<StageId, StageState>>>,
206    job_id: Uuid,
207) -> DistResult<()> {
208    let alive_nodes = cluster.alive_nodes().await?;
209
210    for node_id in alive_nodes.keys() {
211        if node_id == local_node {
212            let mut guard = local_stages.lock();
213            guard.retain(|stage_id, _| stage_id.job_id != job_id);
214            drop(guard);
215        } else {
216            // Send cleanup request to remote node
217            network.cleanup_job(node_id.clone(), job_id).await?
218        }
219    }
220    Ok(())
221}