datafusion_dist/
runtime.rs

1use std::{
2    collections::{HashMap, HashSet},
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6};
7
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use uuid::Uuid;
11
12use arrow::array::RecordBatch;
13use datafusion_execution::TaskContext;
14use datafusion_physical_plan::ExecutionPlan;
15
16use futures::{Stream, StreamExt};
17use log::{debug, error};
18use tokio::sync::mpsc::Sender;
19
20use crate::{
21    DistError, DistResult, RecordBatchStream,
22    cluster::{DistCluster, NodeId, NodeStatus},
23    config::DistConfig,
24    event::{Event, EventHandler, cleanup_job, local_stage_stats, start_event_handler},
25    executor::{DefaultExecutor, DistExecutor, logging_executor_metrics},
26    heartbeat::Heartbeater,
27    network::{DistNetwork, ScheduledTasks, StageInfo},
28    planner::{
29        DefaultPlanner, DisplayableStagePlans, DistPlanner, StageId, TaskId,
30        check_initial_stage_plans, resolve_stage_plan,
31    },
32    scheduler::{DefaultScheduler, DisplayableTaskDistribution, DistScheduler},
33    util::{ReceiverStreamBuilder, timestamp_ms},
34};
35
/// Per-node runtime that wires together planning, scheduling, execution and
/// networking for distributed queries. Cloning is cheap: all mutable state
/// lives behind `Arc`s and is shared between clones.
#[derive(Debug, Clone)]
pub struct DistRuntime {
    /// Identity of this node within the cluster.
    pub node_id: NodeId,
    /// Current lifecycle status (e.g. `Available`/`Terminating`), shared with the heartbeater.
    pub status: Arc<Mutex<NodeStatus>>,
    /// DataFusion task context used when executing local stage plans.
    pub task_ctx: Arc<TaskContext>,
    /// Runtime configuration (heartbeat interval, job TTL, ...).
    pub config: Arc<DistConfig>,
    /// Cluster membership/liveness view.
    pub cluster: Arc<dyn DistCluster>,
    /// Transport used to ship tasks and execute/query remote nodes.
    pub network: Arc<dyn DistNetwork>,
    /// Splits a physical plan into per-stage plans.
    pub planner: Arc<dyn DistPlanner>,
    /// Assigns tasks to alive nodes.
    pub scheduler: Arc<dyn DistScheduler>,
    /// Local execution backend providing the runtime handle for task streams.
    pub executor: Arc<dyn DistExecutor>,
    /// Periodic liveness reporter for this node.
    pub heartbeater: Arc<Heartbeater>,
    /// Stage state held locally, keyed by stage id; shared with background tasks.
    pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    /// Channel into the event loop started in [`DistRuntime::new`].
    pub event_sender: Sender<Event>,
}
51
52impl DistRuntime {
53    pub fn new(
54        task_ctx: Arc<TaskContext>,
55        config: Arc<DistConfig>,
56        cluster: Arc<dyn DistCluster>,
57        network: Arc<dyn DistNetwork>,
58    ) -> Self {
59        let node_id = network.local_node();
60        let status = Arc::new(Mutex::new(NodeStatus::Available));
61        let stages = Arc::new(Mutex::new(HashMap::new()));
62        let heartbeater = Heartbeater {
63            node_id: node_id.clone(),
64            cluster: cluster.clone(),
65            stages: stages.clone(),
66            heartbeat_interval: config.heartbeat_interval,
67            status: status.clone(),
68        };
69
70        let (sender, receiver) = tokio::sync::mpsc::channel::<Event>(1024);
71
72        let event_handler = EventHandler {
73            local_node: node_id.clone(),
74            config: config.clone(),
75            cluster: cluster.clone(),
76            network: network.clone(),
77            local_stages: stages.clone(),
78            sender: sender.clone(),
79            receiver,
80        };
81        start_event_handler(event_handler);
82
83        Self {
84            node_id: network.local_node(),
85            status,
86            task_ctx,
87            config,
88            cluster,
89            network,
90            planner: Arc::new(DefaultPlanner),
91            scheduler: Arc::new(DefaultScheduler::new()),
92            executor: Arc::new(DefaultExecutor::new()),
93            heartbeater: Arc::new(heartbeater),
94            stages,
95            event_sender: sender,
96        }
97    }
98
99    pub fn with_planner(self, planner: Arc<dyn DistPlanner>) -> Self {
100        Self { planner, ..self }
101    }
102
103    pub fn with_scheduler(self, scheduler: Arc<dyn DistScheduler>) -> Self {
104        Self { scheduler, ..self }
105    }
106
107    pub fn with_executor(self, executor: Arc<dyn DistExecutor>) -> Self {
108        Self { executor, ..self }
109    }
110
    /// Starts the background services: the heartbeater and the periodic
    /// job-TTL cleaner.
    // NOTE(review): contains no awaits today; kept `async` for interface stability.
    pub async fn start(&self) {
        self.heartbeater.start();
        start_job_cleaner(self.stages.clone(), self.config.clone());
    }
115
    /// Initiates a graceful shutdown: flips this node's status to
    /// `Terminating` and immediately heartbeats so the rest of the cluster
    /// stops assigning new tasks here as soon as possible.
    pub async fn shutdown(&self) {
        // Set status to Terminating
        *self.status.lock() = NodeStatus::Terminating;
        debug!("Set node status to Terminating, no new tasks will be assigned");

        self.heartbeater.send_heartbeat().await;
    }
123
    /// Plans, schedules and distributes a physical plan as a new job across
    /// the cluster. Returns the generated job id together with the node
    /// placement of the stage-0 tasks, so the caller knows where to fetch
    /// final results from.
    ///
    /// # Errors
    /// Fails when planning or scheduling fails, when no stage-0 tasks were
    /// scheduled, or when shipping tasks to a remote node fails.
    pub async fn submit(
        &self,
        plan: Arc<dyn ExecutionPlan>,
    ) -> DistResult<(Uuid, HashMap<TaskId, NodeId>)> {
        let job_id = Uuid::new_v4();
        // 1. Split the plan into per-stage plans.
        let mut stage_plans = self.planner.plan_stages(job_id, plan)?;
        debug!(
            "job {job_id} initial stage plans:\n{}",
            DisplayableStagePlans(&stage_plans)
        );
        check_initial_stage_plans(job_id, &stage_plans)?;

        // 2. Map every task onto an alive node via the scheduler.
        let node_states = self.cluster.alive_nodes().await?;
        debug!(
            "alive nodes: {}",
            node_states
                .keys()
                .map(|n| n.to_string())
                .collect::<Vec<_>>()
                .join(", ")
        );

        let task_distribution = self
            .scheduler
            .schedule(&self.node_id, &node_states, &stage_plans)
            .await?;
        debug!(
            "job {job_id} task distribution: {}",
            DisplayableTaskDistribution(&task_distribution)
        );
        // Stage 0 produces the job's results; its placement is returned to the caller.
        let stage0_task_distribution: HashMap<TaskId, NodeId> = task_distribution
            .iter()
            .filter(|(task_id, _)| task_id.stage == 0)
            .map(|(task_id, node_id)| (*task_id, node_id.clone()))
            .collect();
        if stage0_task_distribution.is_empty() {
            return Err(DistError::internal(format!(
                "Not found stage0 task distribution in {task_distribution:?} for job {job_id}"
            )));
        }

        // Resolve stage plans based on task distribution
        for (_, stage_plan) in stage_plans.iter_mut() {
            *stage_plan = resolve_stage_plan(stage_plan.clone(), &task_distribution, self.clone())?;
        }
        debug!(
            "job {job_id} final stage plans:\n{}",
            DisplayableStagePlans(&stage_plans)
        );

        // 3. Group the stages and tasks by the node they were assigned to.
        let mut node_stages = HashMap::new();
        let mut node_tasks = HashMap::new();
        for (task_id, node_id) in task_distribution.iter() {
            node_stages
                .entry(node_id.clone())
                .or_insert_with(HashSet::new)
                .insert(task_id.stage_id());
            node_tasks
                .entry(node_id.clone())
                .or_insert_with(Vec::new)
                .push(*task_id);
        }

        // Send stage plans to cluster nodes
        let mut handles = Vec::with_capacity(node_stages.len());
        for (node_id, stage_ids) in node_stages {
            let node_stage_plans = stage_ids
                .iter()
                .map(|stage_id| {
                    (
                        *stage_id,
                        stage_plans
                            .get(stage_id)
                            .cloned()
                            .expect("stage id should be valid"),
                    )
                })
                .collect::<HashMap<_, _>>();

            let tasks = node_tasks.get(&node_id).cloned().unwrap_or_default();

            let scheduled_tasks =
                ScheduledTasks::new(node_stage_plans, tasks, Arc::new(task_distribution.clone()));

            if node_id == self.node_id {
                // Tasks assigned to this node skip the network round-trip.
                self.receive_tasks(scheduled_tasks).await?;
            } else {
                debug!(
                    "Sending job {job_id} tasks [{}] to {node_id}",
                    scheduled_tasks
                        .task_ids
                        .iter()
                        .map(|t| format!("{}/{}", t.stage, t.partition))
                        .collect::<Vec<String>>()
                        .join(", ")
                );
                let network = self.network.clone();
                // Ship to remote nodes concurrently; joined right below.
                let handle = tokio::spawn(async move {
                    network.send_tasks(node_id.clone(), scheduled_tasks).await?;
                    Ok::<_, DistError>(())
                });
                handles.push(handle);
            }
        }

        for handle in handles {
            handle.await??;
        }

        logging_executor_metrics(self.executor.handle());

        Ok((job_id, stage0_task_distribution))
    }
237
    /// Executes the given task's partition of its stage plan on this node and
    /// returns the resulting record-batch stream.
    ///
    /// Resolves (and, via [`StageState::get_plan`], lazily creates) the task
    /// set the partition runs in, drives the DataFusion stream on the
    /// executor's runtime through a small bounded channel, and wraps the
    /// receiving end in a [`TaskStream`] so metrics/completion are reported
    /// when the stream is dropped.
    ///
    /// # Errors
    /// Fails when the stage is unknown locally or the partition was not
    /// assigned to this node.
    pub async fn execute_local(&self, task_id: TaskId) -> DistResult<RecordBatchStream> {
        let stage_id = task_id.stage_id();

        let mut guard = self.stages.lock();
        let stage_state = guard
            .get_mut(&stage_id)
            .ok_or_else(|| DistError::internal(format!("Stage {stage_id} not found")))?;
        let (task_set_id, plan) = stage_state.get_plan(task_id.partition as usize)?;
        // Release the stage-map lock before any async work below.
        drop(guard);

        let mut receiver_stream_builder = ReceiverStreamBuilder::new(2);

        let tx = receiver_stream_builder.tx();
        let partition = task_id.partition as usize;
        let task_ctx = self.task_ctx.clone();
        let driver_task = async move {
            let mut df_stream = plan.execute(partition, task_ctx)?;

            while let Some(batch) = df_stream.next().await {
                let batch = batch.map_err(DistError::from);
                if tx.send(batch).await.is_err() {
                    // error means dropped receiver, so nothing will get results anymore
                    return Ok(());
                }
            }
            Ok(()) as DistResult<()>
        };

        // Run the driver on the executor's runtime, not the caller's.
        receiver_stream_builder.spawn_on(driver_task, self.executor.handle());

        let stream = receiver_stream_builder.build();

        let task_stream = TaskStream::new(
            task_id,
            task_set_id,
            self.stages.clone(),
            self.event_sender.clone(),
            stream,
        );

        Ok(task_stream.boxed())
    }
280
281    pub async fn execute_remote(
282        &self,
283        node_id: NodeId,
284        task_id: TaskId,
285    ) -> DistResult<RecordBatchStream> {
286        if node_id == self.node_id {
287            return Err(DistError::internal(format!(
288                "remote node id {node_id} is actually self"
289            )));
290        }
291
292        debug!("Executing remote task {task_id} on node {node_id}");
293        self.network.execute_task(node_id, task_id).await
294    }
295
    /// Installs stage plans and task assignments delivered to this node —
    /// either by a remote submitter over the network or locally from
    /// [`DistRuntime::submit`].
    ///
    /// Converts the scheduled tasks into [`StageState`]s, merges them into the
    /// local stage map, and — if any stage-0 tasks landed here — notifies the
    /// event loop so job completion can be tracked.
    pub async fn receive_tasks(&self, scheduled_tasks: ScheduledTasks) -> DistResult<()> {
        debug!(
            "Received job {} tasks: [{}] and plans of stages: [{}]",
            scheduled_tasks.job_id()?,
            scheduled_tasks
                .task_ids
                .iter()
                .map(|t| format!("{}/{}", t.stage, t.partition))
                .collect::<Vec<String>>()
                .join(", "),
            scheduled_tasks
                .stage_plans
                .keys()
                .map(|k| k.stage.to_string())
                .collect::<Vec<String>>()
                .join(", ")
        );

        let stage_states = StageState::from_scheduled_tasks(scheduled_tasks)?;
        let stage_ids = stage_states.keys().cloned().collect::<Vec<StageId>>();
        {
            // Keep the lock scope tight: the `send` below awaits and must not
            // be reached while holding the stage-map mutex.
            let mut guard = self.stages.lock();
            guard.extend(stage_states);
            drop(guard);
        }

        let stage0_ids = stage_ids
            .iter()
            .filter(|id| id.stage == 0)
            .cloned()
            .collect::<Vec<StageId>>();
        if !stage0_ids.is_empty() {
            self.event_sender
                .send(Event::ReceivedStage0Tasks(stage0_ids))
                .await
                .map_err(|e| {
                    DistError::internal(format!("Failed to send ReceivedStage0Tasks event: {e}"))
                })?;
        }

        Ok(())
    }
338
339    pub fn cleanup_local_job(&self, job_id: Uuid) {
340        let mut guard = self.stages.lock();
341        let stage_ids = guard
342            .iter()
343            .filter(|(stage_id, _)| stage_id.job_id == job_id)
344            .map(|(stage_id, _)| stage_id)
345            .collect::<Vec<_>>();
346        if !stage_ids.is_empty() {
347            debug!(
348                "Cleaning up local Job {job_id} stages [{}]",
349                stage_ids
350                    .iter()
351                    .map(|id| id.stage.to_string())
352                    .collect::<Vec<_>>()
353                    .join(", ")
354            );
355            guard.retain(|stage_id, _| stage_id.job_id != job_id);
356        }
357    }
358
    /// Cleans up all stage state for `job_id` by delegating to the shared
    /// [`cleanup_job`] event helper.
    // NOTE(review): the helper receives the cluster and network handles, so it
    // presumably also cleans up the job on remote nodes — confirm in `event::cleanup_job`.
    pub async fn cleanup_job(&self, job_id: Uuid) -> DistResult<()> {
        cleanup_job(
            &self.node_id,
            &self.cluster,
            &self.network,
            &self.stages,
            job_id,
        )
        .await
    }
369
    /// Returns the locally-known stage statistics for a single job.
    pub fn get_local_job(&self, job_id: Uuid) -> HashMap<StageId, StageInfo> {
        local_stage_stats(&self.stages, Some(job_id))
    }
373
374    pub fn get_local_jobs(&self) -> HashMap<Uuid, HashMap<StageId, StageInfo>> {
375        let stage_stat = local_stage_stats(&self.stages, None);
376
377        // Aggregate stats by job_id
378        let mut job_stats: HashMap<Uuid, HashMap<StageId, StageInfo>> = HashMap::new();
379        for (stage_id, stage_info) in stage_stat {
380            job_stats
381                .entry(stage_id.job_id)
382                .or_default()
383                .insert(stage_id, stage_info);
384        }
385
386        job_stats
387    }
388
    /// Returns cluster-wide stage statistics grouped by job id: local stats
    /// merged with the stats fetched concurrently from every other alive node.
    ///
    /// For a stage reported by multiple nodes, the assigned partitions and
    /// task-set infos are merged into a single [`StageInfo`].
    ///
    /// # Errors
    /// Fails when the cluster cannot list alive nodes or a remote status
    /// request fails.
    pub async fn get_all_jobs(&self) -> DistResult<HashMap<Uuid, HashMap<StageId, StageInfo>>> {
        // First, get local status for all jobs
        let mut combined_status = local_stage_stats(&self.stages, None);

        // Then, get status from all other alive nodes
        let node_states = self.cluster.alive_nodes().await?;

        let mut handles = Vec::new();
        for node_id in node_states.keys() {
            if *node_id != self.node_id {
                let network = self.network.clone();
                let node_id = node_id.clone();
                // Fan out the remote status requests concurrently.
                let handle =
                    tokio::spawn(async move { network.get_job_status(node_id, None).await });
                handles.push(handle);
            }
        }

        for handle in handles {
            let remote_status = handle.await??;
            for (stage_id, remote_stage_info) in remote_status {
                combined_status
                    .entry(stage_id)
                    .and_modify(|existing| {
                        // Same stage seen on several nodes: union the partition
                        // assignments and task-set infos.
                        existing
                            .assigned_partitions
                            .extend(&remote_stage_info.assigned_partitions);
                        existing
                            .task_set_infos
                            .extend(remote_stage_info.task_set_infos.clone());
                    })
                    .or_insert(remote_stage_info);
            }
        }

        // Aggregate stats by job_id
        let mut job_stats: HashMap<Uuid, HashMap<StageId, StageInfo>> = HashMap::new();
        for (stage_id, stage_info) in combined_status {
            job_stats
                .entry(stage_id.job_id)
                .or_default()
                .insert(stage_id, stage_info);
        }

        Ok(job_stats)
    }
435}
436
/// Locally-held execution state for one stage of a job.
#[derive(Debug)]
pub struct StageState {
    /// The stage this state belongs to.
    pub stage_id: StageId,
    /// Creation timestamp (ms), used by the job-TTL cleaner.
    pub create_at_ms: i64,
    /// The stage's physical plan as received; task sets execute reset copies of it.
    pub stage_plan: Arc<dyn ExecutionPlan>,
    /// Partitions of this stage that were scheduled onto this node.
    pub assigned_partitions: HashSet<usize>,
    /// One entry per execution attempt; see [`TaskSet`].
    pub task_sets: Vec<TaskSet>,
    /// The job-wide task-to-node mapping, shared across this job's stages.
    pub job_task_distribution: Arc<HashMap<TaskId, NodeId>>,
}
446
447impl StageState {
    /// Builds one [`StageState`] per stage out of the tasks scheduled onto
    /// this node, pairing each stage's assigned partitions with its plan.
    ///
    /// # Errors
    /// Fails when a scheduled task references a stage for which no plan was
    /// delivered.
    pub fn from_scheduled_tasks(
        scheduled_tasks: ScheduledTasks,
    ) -> DistResult<HashMap<StageId, StageState>> {
        // Group the incoming task ids by their stage.
        let mut stage_tasks: HashMap<StageId, HashSet<TaskId>> = HashMap::new();
        for task_id in scheduled_tasks.task_ids {
            let stage_id = task_id.stage_id();
            stage_tasks.entry(stage_id).or_default().insert(task_id);
        }

        let mut stage_states = HashMap::new();
        for (stage_id, assigned_task_ids) in stage_tasks {
            let stage_state = StageState {
                stage_id,
                create_at_ms: timestamp_ms(),
                stage_plan: scheduled_tasks
                    .stage_plans
                    .get(&stage_id)
                    .ok_or_else(|| {
                        DistError::internal(format!("Not found plan of stage {stage_id}"))
                    })?
                    .clone(),
                assigned_partitions: assigned_task_ids
                    .iter()
                    .map(|task_id| task_id.partition as usize)
                    .collect(),
                // Task sets are created lazily on first execution (see `get_plan`).
                task_sets: Vec::new(),
                job_task_distribution: scheduled_tasks.job_task_distribution.clone(),
            };
            stage_states.insert(stage_id, stage_state);
        }
        Ok(stage_states)
    }
480
481    pub fn num_running_tasks(&self) -> usize {
482        self.task_sets
483            .iter()
484            .map(|task_set| task_set.running_partitions.len())
485            .sum()
486    }
487
488    pub fn get_plan(&mut self, partition: usize) -> DistResult<(Uuid, Arc<dyn ExecutionPlan>)> {
489        if !self.assigned_partitions.contains(&partition) {
490            let task_id = self.stage_id.task_id(partition as u32);
491            return Err(DistError::internal(format!(
492                "Task {task_id} not found in this node"
493            )));
494        }
495
496        for task_set in self.task_sets.iter_mut() {
497            if !task_set.never_executed(&partition) {
498                task_set.running_partitions.insert(partition);
499                return Ok((task_set.id, task_set.shared_plan.clone()));
500            }
501        }
502
503        let task_set_id = Uuid::new_v4();
504        let mut new_task_set = TaskSet {
505            id: task_set_id,
506            shared_plan: self.stage_plan.clone().reset_state()?,
507            running_partitions: HashSet::new(),
508            dropped_partitions: HashMap::new(),
509        };
510        new_task_set.running_partitions.insert(partition);
511        let shared_plan = new_task_set.shared_plan.clone();
512        self.task_sets.push(new_task_set);
513
514        Ok((task_set_id, shared_plan))
515    }
516
517    pub fn complete_task(&mut self, task_id: TaskId, task_set_id: Uuid, task_metrics: TaskMetrics) {
518        if let Some(task_set) = self
519            .task_sets
520            .iter_mut()
521            .find(|task_set| task_set.id == task_set_id)
522        {
523            task_set
524                .running_partitions
525                .remove(&(task_id.partition as usize));
526            task_set
527                .dropped_partitions
528                .insert(task_id.partition as usize, task_metrics);
529        }
530    }
531
532    pub fn assigned_partitions_executed_at_least_once(&self) -> bool {
533        let executed_partitions = self
534            .task_sets
535            .iter()
536            .flat_map(|task_set| {
537                let mut executed = task_set.running_partitions.clone();
538                executed.extend(task_set.dropped_partitions.keys());
539                executed
540            })
541            .collect::<HashSet<_>>();
542
543        for partition in self.assigned_partitions.iter() {
544            if !executed_partitions.contains(partition) {
545                return false;
546            }
547        }
548        true
549    }
550
551    pub fn never_executed(&self) -> bool {
552        self.task_sets
553            .iter()
554            .all(|set| set.running_partitions.is_empty() && set.dropped_partitions.is_empty())
555    }
556}
557
/// One execution attempt of a stage: a single plan instance shared by the
/// partitions that run against it.
#[derive(Debug)]
pub struct TaskSet {
    /// Unique id of this attempt, echoed back by `complete_task`.
    pub id: Uuid,
    /// The plan instance shared by all partitions of this attempt.
    pub shared_plan: Arc<dyn ExecutionPlan>,
    /// Partitions currently executing in this attempt.
    pub running_partitions: HashSet<usize>,
    /// Partitions whose streams were dropped, with their final metrics.
    pub dropped_partitions: HashMap<usize, TaskMetrics>,
}
565
566impl TaskSet {
567    pub fn never_executed(&self, partition: &usize) -> bool {
568        !self.running_partitions.contains(partition)
569            && !self.dropped_partitions.contains_key(partition)
570    }
571}
572
/// Output metrics reported when a task's stream is dropped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMetrics {
    /// Total rows yielded by the stream.
    pub output_rows: usize,
    /// Total in-memory bytes of the yielded batches.
    pub output_bytes: usize,
    /// True if the stream ran to exhaustion (vs. dropped early by the consumer).
    pub completed: bool,
}
579
/// Stream wrapper around a task's record-batch stream that accumulates output
/// metrics while polling and reports them (via `Drop`) when consumption ends.
pub struct TaskStream {
    /// The task whose output this stream carries.
    pub task_id: TaskId,
    /// The task set the task executed in; needed to record completion.
    pub task_set_id: Uuid,
    /// Shared stage map, updated on drop.
    pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    /// Channel to the event loop for completion notifications.
    pub event_sender: Sender<Event>,
    /// The underlying record-batch stream being forwarded.
    pub stream: RecordBatchStream,
    /// Rows yielded so far.
    pub output_rows: usize,
    /// In-memory bytes yielded so far.
    pub output_bytes: usize,
    /// Set once the inner stream returned `None`.
    pub completed: bool,
}
590
impl TaskStream {
    /// Wraps an already-running record-batch stream so that per-task metrics
    /// are accumulated while polling and reported back on drop.
    pub fn new(
        task_id: TaskId,
        task_set_id: Uuid,
        stages: Arc<Mutex<HashMap<StageId, StageState>>>,
        event_sender: Sender<Event>,
        stream: RecordBatchStream,
    ) -> Self {
        Self {
            task_id,
            task_set_id,
            stages,
            event_sender,
            stream,
            // Metrics start at zero and grow as batches are polled through.
            output_rows: 0,
            output_bytes: 0,
            completed: false,
        }
    }
}
611
612impl Stream for TaskStream {
613    type Item = DistResult<RecordBatch>;
614
615    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
616        match self.stream.as_mut().poll_next(cx) {
617            Poll::Ready(Some(Ok(batch))) => {
618                self.output_rows += batch.num_rows();
619                self.output_bytes += batch.get_array_memory_size();
620                Poll::Ready(Some(Ok(batch)))
621            }
622            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
623            Poll::Ready(None) => {
624                self.completed = true;
625                Poll::Ready(None)
626            }
627            Poll::Pending => Poll::Pending,
628        }
629    }
630}
631
impl Drop for TaskStream {
    /// Reports the task's final metrics when the consumer drops the stream
    /// (whether it was fully consumed or abandoned early).
    ///
    /// Marks the task completed in its stage state and, when this was the last
    /// outstanding stage-0 partition on this node, asks the event loop to
    /// check whether the whole job finished.
    // NOTE(review): uses `tokio::spawn` because `Drop` cannot be async; this
    // assumes the drop happens inside a tokio runtime context — confirm.
    fn drop(&mut self) {
        let task_id = self.task_id;
        let task_set_id = self.task_set_id;
        let task_metrics = TaskMetrics {
            output_bytes: self.output_bytes,
            output_rows: self.output_rows,
            completed: self.completed,
        };
        debug!("Task {task_id} dropped with metrics: {task_metrics:?}");

        let stages = self.stages.clone();
        let sender = self.event_sender.clone();
        tokio::spawn(async move {
            let mut send_event = false;
            {
                // Lock is confined to this block; the `send` below awaits and
                // must not run while the guard is held.
                let mut guard = stages.lock();
                if let Some(stage_state) = guard.get_mut(&task_id.stage_id()) {
                    stage_state.complete_task(task_id, task_set_id, task_metrics);
                    if stage_state.stage_id.stage == 0
                        && stage_state.assigned_partitions_executed_at_least_once()
                    {
                        send_event = true;
                    }
                }
            }

            if send_event
                && let Err(e) = sender.send(Event::CheckJobCompleted(task_id.job_id)).await
            {
                error!(
                    "Failed to send CheckJobCompleted event after task {task_id} stream dropped: {e}"
                );
            }
        });
    }
}
669
670fn start_job_cleaner(stages: Arc<Mutex<HashMap<StageId, StageState>>>, config: Arc<DistConfig>) {
671    tokio::spawn(async move {
672        loop {
673            tokio::time::sleep(config.job_ttl_check_interval).await;
674
675            let mut guard = stages.lock();
676            let mut to_cleanup = Vec::new();
677            for (stage_id, stage_state) in guard.iter() {
678                let age_ms = timestamp_ms() - stage_state.create_at_ms;
679                if age_ms >= config.job_ttl.as_millis() as i64 {
680                    to_cleanup.push(*stage_id);
681                }
682            }
683
684            if !to_cleanup.is_empty() {
685                debug!(
686                    "Stages [{}] lifetime exceed job ttl {}s, cleaning up.",
687                    to_cleanup
688                        .iter()
689                        .map(|id| id.to_string())
690                        .collect::<Vec<_>>()
691                        .join(", "),
692                    config.job_ttl.as_secs()
693                );
694                guard.retain(|stage_id, _| !to_cleanup.contains(stage_id));
695            }
696            drop(guard);
697        }
698    });
699}