1use std::{
2 collections::{HashMap, HashSet},
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use uuid::Uuid;
11
12use arrow::array::RecordBatch;
13use arrow::datatypes::Schema;
14use datafusion_common::DataFusionError;
15use datafusion_execution::{SendableRecordBatchStream, TaskContext};
16use datafusion_physical_plan::{
17 ExecutionPlan, RecordBatchStream, display::DisplayableExecutionPlan,
18 execution_plan::reset_plan_states, stream::RecordBatchStreamAdapter,
19};
20
21use futures::{Stream, StreamExt, TryStreamExt, future::join_all};
22use log::{debug, error, warn};
23use tokio::{sync::mpsc::Sender, task::AbortHandle};
24
25use crate::{
26 DistError, DistResult, JobId,
27 cluster::{DistCluster, NodeId, NodeStatus},
28 config::DistConfig,
29 event::{Event, EventHandler, local_jobs, send_event_with_timeout, start_event_handler},
30 executor::{DefaultExecutor, DistExecutor, logging_executor_metrics},
31 heartbeat::Heartbeater,
32 network::{DistNetwork, ScheduledTasks, StageInfo},
33 planner::{
34 DefaultPlanner, DisplayableStagePlans, DistPlanner, StageId, TaskId,
35 check_initial_stage_plans, resolve_stage_plan,
36 },
37 scheduler::{DefaultScheduler, DisplayableTaskDistribution, DistScheduler},
38 util::{ReceiverStreamBuilder, timestamp_ms},
39};
40
41#[derive(Debug, Clone)]
42pub struct DistRuntime {
43 pub node_id: NodeId,
44 pub status: Arc<Mutex<NodeStatus>>,
45 pub task_ctx: Arc<TaskContext>,
46 pub config: Arc<DistConfig>,
47 pub cluster: Arc<dyn DistCluster>,
48 pub network: Arc<dyn DistNetwork>,
49 pub planner: Arc<dyn DistPlanner>,
50 pub scheduler: Arc<dyn DistScheduler>,
51 pub executor: Arc<dyn DistExecutor>,
52 pub heartbeater: Arc<Heartbeater>,
53 pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
54 pub event_sender: Sender<Event>,
55}
56
57impl DistRuntime {
58 pub fn new(
59 task_ctx: Arc<TaskContext>,
60 config: Arc<DistConfig>,
61 cluster: Arc<dyn DistCluster>,
62 network: Arc<dyn DistNetwork>,
63 ) -> Self {
64 let node_id = network.local_node();
65 let status = Arc::new(Mutex::new(NodeStatus::Available));
66 let stages = Arc::new(Mutex::new(HashMap::new()));
67 let heartbeater = Heartbeater {
68 node_id: node_id.clone(),
69 cluster: cluster.clone(),
70 stages: stages.clone(),
71 heartbeat_interval: config.heartbeat_interval,
72 status: status.clone(),
73 };
74
75 let (sender, receiver) = tokio::sync::mpsc::channel::<Event>(config.event_queue_size);
76
77 let event_handler = EventHandler {
78 config: config.clone(),
79 cluster: cluster.clone(),
80 network: network.clone(),
81 local_stages: stages.clone(),
82 sender: sender.clone(),
83 receiver,
84 };
85 start_event_handler(event_handler);
86
87 Self {
88 node_id: network.local_node(),
89 status,
90 task_ctx,
91 config,
92 cluster,
93 network,
94 planner: Arc::new(DefaultPlanner),
95 scheduler: Arc::new(DefaultScheduler::new()),
96 executor: Arc::new(DefaultExecutor::new()),
97 heartbeater: Arc::new(heartbeater),
98 stages,
99 event_sender: sender,
100 }
101 }
102
103 pub fn with_planner(self, planner: Arc<dyn DistPlanner>) -> Self {
104 Self { planner, ..self }
105 }
106
107 pub fn with_scheduler(self, scheduler: Arc<dyn DistScheduler>) -> Self {
108 Self { scheduler, ..self }
109 }
110
111 pub fn with_executor(self, executor: Arc<dyn DistExecutor>) -> Self {
112 Self { executor, ..self }
113 }
114
115 pub async fn start(&self) {
116 self.heartbeater.start();
117 start_job_cleaner(self.stages.clone(), self.config.clone());
118 }
119
120 pub async fn shutdown(&self) {
121 *self.status.lock() = NodeStatus::Terminating;
123 debug!("Set node status to Terminating, no new tasks will be assigned");
124
125 self.heartbeater.send_heartbeat().await;
126 }
127
128 pub async fn submit(
129 &self,
130 job_id: impl Into<JobId>,
131 plan: Arc<dyn ExecutionPlan>,
132 job_meta: Arc<HashMap<String, String>>,
133 ) -> DistResult<HashMap<TaskId, NodeId>> {
134 let job_id = job_id.into();
135 debug!(
136 "Submitting job {job_id} with meta {job_meta:?} and physical plan: \n{}",
137 DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
138 );
139
140 let mut stage_plans = self.planner.plan_stages(job_id.clone(), plan)?;
141 debug!(
142 "job {job_id} initial stage plans:\n{}",
143 DisplayableStagePlans(&stage_plans)
144 );
145 check_initial_stage_plans(job_id.clone(), &stage_plans)?;
146
147 let node_states = self.cluster.alive_nodes().await?;
148 debug!(
149 "alive nodes: {}",
150 node_states
151 .keys()
152 .map(|n| n.to_string())
153 .collect::<Vec<_>>()
154 .join(", ")
155 );
156
157 let task_distribution = self
158 .scheduler
159 .schedule(&self.node_id, &node_states, &stage_plans)
160 .await?;
161 debug!(
162 "job {job_id} task distribution: {}",
163 DisplayableTaskDistribution(&task_distribution)
164 );
165 let stage0_task_distribution: HashMap<TaskId, NodeId> = task_distribution
166 .iter()
167 .filter(|(task_id, _)| task_id.stage == 0)
168 .map(|(task_id, node_id)| (task_id.clone(), node_id.clone()))
169 .collect();
170 if stage0_task_distribution.is_empty() {
171 return Err(DistError::internal(format!(
172 "Not found stage0 task distribution in {task_distribution:?} for job {job_id}"
173 )));
174 }
175
176 for (_, stage_plan) in stage_plans.iter_mut() {
178 *stage_plan = resolve_stage_plan(stage_plan.clone(), &task_distribution, self.clone())?;
179 }
180 debug!(
181 "job {job_id} final stage plans:\n{}",
182 DisplayableStagePlans(&stage_plans)
183 );
184
185 let mut node_stages = HashMap::new();
186 let mut node_tasks = HashMap::new();
187 for (task_id, node_id) in task_distribution.iter() {
188 node_stages
189 .entry(node_id.clone())
190 .or_insert_with(HashSet::new)
191 .insert(task_id.stage_id());
192 node_tasks
193 .entry(node_id.clone())
194 .or_insert_with(Vec::new)
195 .push(task_id.clone());
196 }
197
198 let mut handles = Vec::with_capacity(node_stages.len());
200 for (node_id, stage_ids) in node_stages {
201 let node_stage_plans = stage_ids
202 .iter()
203 .map(|stage_id| {
204 (
205 stage_id.clone(),
206 stage_plans
207 .get(stage_id)
208 .cloned()
209 .expect("stage id should be valid"),
210 )
211 })
212 .collect::<HashMap<_, _>>();
213
214 let tasks = node_tasks.get(&node_id).cloned().unwrap_or_default();
215
216 let scheduled_tasks = ScheduledTasks::new(
217 node_stage_plans,
218 tasks,
219 Arc::new(task_distribution.clone()),
220 job_meta.clone(),
221 );
222
223 if node_id == self.node_id {
224 self.receive_tasks(scheduled_tasks).await?;
225 } else {
226 debug!(
227 "Sending job {job_id} tasks [{}] to {node_id}",
228 scheduled_tasks
229 .task_ids
230 .iter()
231 .map(|t| format!("{}/{}", t.stage, t.partition))
232 .collect::<Vec<String>>()
233 .join(", ")
234 );
235 let network = self.network.clone();
236 let handle = tokio::spawn(async move {
237 network.send_tasks(node_id.clone(), scheduled_tasks).await?;
238 Ok::<_, DistError>(())
239 });
240 handles.push(handle);
241 }
242 }
243
244 for handle in handles {
245 handle.await??;
246 }
247
248 logging_executor_metrics(self.executor.handle());
249
250 Ok(stage0_task_distribution)
251 }
252
253 pub async fn execute_local(&self, task_id: TaskId) -> DistResult<SendableRecordBatchStream> {
254 let stage_id = task_id.stage_id();
255
256 let mut guard = self.stages.lock();
257 let stage_state = guard
258 .get_mut(&stage_id)
259 .ok_or_else(|| DistError::internal(format!("Stage {stage_id} not found")))?;
260 let (task_set_id, plan) = stage_state.get_plan(task_id.partition as usize)?;
261 let schema = plan.schema();
262
263 let mut receiver_stream_builder = ReceiverStreamBuilder::new(2);
264
265 let tx = receiver_stream_builder.tx();
266 let partition = task_id.partition as usize;
267 let task_ctx = self.task_ctx.clone();
268 let driver_task = async move {
269 let mut df_stream = plan.execute(partition, task_ctx)?;
270
271 while let Some(batch) = df_stream.next().await {
272 let batch = batch.map_err(DistError::from);
273 match tx.send(batch).await {
274 Ok(()) => {}
275 Err(e) => {
276 warn!("Dist driver task failed to send batch to channel: {e}");
277 return Ok(());
278 }
279 }
280 }
281 Ok(()) as DistResult<()>
282 };
283
284 let abort_handle = receiver_stream_builder.spawn_on(driver_task, self.executor.handle());
285 stage_state.start_task(task_id.partition as usize, task_set_id, abort_handle)?;
286 drop(guard);
287
288 let stream = receiver_stream_builder.build();
289 let stream = Box::pin(RecordBatchStreamAdapter::new(
290 schema,
291 stream.map_err(DataFusionError::from),
292 ));
293
294 let task_stream = TaskStream::new(
295 task_id,
296 task_set_id,
297 self.stages.clone(),
298 self.event_sender.clone(),
299 stream,
300 );
301
302 Ok(Box::pin(task_stream))
303 }
304
305 pub async fn execute_remote(
306 &self,
307 node_id: NodeId,
308 task_id: TaskId,
309 ) -> DistResult<SendableRecordBatchStream> {
310 if node_id == self.node_id {
311 return Err(DistError::internal(format!(
312 "remote node id {node_id} is actually self"
313 )));
314 }
315
316 debug!("Executing remote task {task_id} on node {node_id}");
317 self.network.execute_task(node_id, task_id).await
318 }
319
320 pub async fn receive_tasks(&self, scheduled_tasks: ScheduledTasks) -> DistResult<()> {
321 if matches!(*self.status.lock(), NodeStatus::Terminating) {
322 return Err(DistError::internal(
323 "Local node is in Terminating status, cannot receive tasks",
324 ));
325 }
326
327 debug!(
328 "Received job {} tasks: [{}] and plans of stages: [{}]",
329 scheduled_tasks.job_id()?,
330 scheduled_tasks
331 .task_ids
332 .iter()
333 .map(|t| format!("{}/{}", t.stage, t.partition))
334 .collect::<Vec<String>>()
335 .join(", "),
336 scheduled_tasks
337 .stage_plans
338 .keys()
339 .map(|k| k.stage.to_string())
340 .collect::<Vec<String>>()
341 .join(", ")
342 );
343
344 let stage_states = StageState::from_scheduled_tasks(scheduled_tasks)?;
345 let stage_ids = stage_states.keys().cloned().collect::<Vec<StageId>>();
346 {
347 let mut guard = self.stages.lock();
348 guard.extend(stage_states);
349 drop(guard);
350 }
351
352 let stage0_ids = stage_ids
353 .iter()
354 .filter(|id| id.stage == 0)
355 .cloned()
356 .collect::<Vec<StageId>>();
357 if !stage0_ids.is_empty() {
358 send_event_with_timeout(&self.event_sender, Event::ReceivedStage0Tasks(stage0_ids))
359 .await?;
360 }
361
362 Ok(())
363 }
364
365 pub fn cleanup_local_jobs(&self, job_ids: Vec<JobId>) {
366 debug!(
367 "Cleaning up local Jobs [{}]",
368 job_ids
369 .iter()
370 .map(|id| id.to_string())
371 .collect::<Vec<_>>()
372 .join(", "),
373 );
374 let job_ids: HashSet<JobId> = job_ids.into_iter().collect();
375 if job_ids.is_empty() {
376 return;
377 }
378
379 cleanup_stages(&mut self.stages.lock(), |stage_id| {
380 job_ids.contains(&stage_id.job_id)
381 });
382 }
383
384 pub fn get_local_jobs(&self, job_ids: Option<&Vec<JobId>>) -> HashMap<StageId, StageInfo> {
385 local_jobs(&self.stages, job_ids)
386 }
387
388 pub async fn get_all_jobs(&self) -> DistResult<HashMap<StageId, StageInfo>> {
389 let mut combined_status = local_jobs(&self.stages, None);
391
392 let node_states = self.cluster.alive_nodes().await?;
394
395 let mut futures = Vec::new();
396 for node_id in node_states.keys() {
397 if *node_id != self.node_id {
398 let network = self.network.clone();
399 let node_id = node_id.clone();
400 futures.push(async move { network.get_jobs(node_id, None).await });
401 }
402 }
403
404 for remote_status in join_all(futures).await {
405 let remote_status = remote_status?;
406 for (stage_id, remote_stage_info) in remote_status {
407 combined_status
408 .entry(stage_id)
409 .and_modify(|existing| {
410 existing.merge(&remote_stage_info);
411 })
412 .or_insert(remote_stage_info);
413 }
414 }
415
416 Ok(combined_status)
417 }
418}
419
420#[derive(Debug)]
422pub struct StageState {
423 pub stage_id: StageId,
424 pub created_at_ms: i64,
425 pub stage_plan: Arc<dyn ExecutionPlan>,
426 pub assigned_partitions: HashSet<usize>,
428 pub task_sets: Vec<TaskSet>,
430 pub job_task_distribution: Arc<HashMap<TaskId, NodeId>>,
432 pub job_meta: Arc<HashMap<String, String>>,
433}
434
435impl StageState {
436 pub fn from_scheduled_tasks(
437 scheduled_tasks: ScheduledTasks,
438 ) -> DistResult<HashMap<StageId, StageState>> {
439 let mut stage_tasks: HashMap<StageId, HashSet<TaskId>> = HashMap::new();
440 for task_id in scheduled_tasks.task_ids {
441 let stage_id = task_id.stage_id();
442 stage_tasks.entry(stage_id).or_default().insert(task_id);
443 }
444
445 let mut stage_states = HashMap::new();
446 for (stage_id, assigned_task_ids) in stage_tasks {
447 let stage_state = StageState {
448 stage_id: stage_id.clone(),
449 created_at_ms: timestamp_ms(),
450 stage_plan: scheduled_tasks
451 .stage_plans
452 .get(&stage_id)
453 .ok_or_else(|| {
454 DistError::internal(format!("Not found plan of stage {stage_id}"))
455 })?
456 .clone(),
457 assigned_partitions: assigned_task_ids
458 .iter()
459 .map(|task_id| task_id.partition as usize)
460 .collect(),
461 task_sets: Vec::new(),
462 job_task_distribution: scheduled_tasks.job_task_distribution.clone(),
463 job_meta: scheduled_tasks.job_meta.clone(),
464 };
465 stage_states.insert(stage_id, stage_state);
466 }
467 Ok(stage_states)
468 }
469
470 pub fn num_running_tasks(&self) -> usize {
471 self.task_sets
472 .iter()
473 .map(|task_set| task_set.running_partitions.len())
474 .sum()
475 }
476
477 pub fn num_pending_tasks(&self) -> usize {
478 let executed_partitions: HashSet<usize> = self
479 .task_sets
480 .iter()
481 .flat_map(|task_set| {
482 let mut executed: HashSet<usize> =
483 task_set.running_partitions.keys().copied().collect();
484 executed.extend(task_set.dropped_partitions.keys());
485 executed
486 })
487 .collect();
488
489 self.assigned_partitions
490 .difference(&executed_partitions)
491 .count()
492 }
493
494 pub fn all_assigned_partitions_completed(&self) -> bool {
496 self.num_running_tasks() == 0 && self.num_pending_tasks() == 0
497 }
498
499 pub fn get_plan(&mut self, partition: usize) -> DistResult<(Uuid, Arc<dyn ExecutionPlan>)> {
500 if !self.assigned_partitions.contains(&partition) {
501 let task_id = self.stage_id.task_id(partition as u32);
502 return Err(DistError::internal(format!(
503 "Task {task_id} not found in this node"
504 )));
505 }
506
507 for task_set in self.task_sets.iter_mut() {
508 if task_set.never_executed(&partition) {
509 return Ok((task_set.id, task_set.shared_plan.clone()));
510 }
511 }
512
513 let task_set_id = Uuid::new_v4();
514 let new_task_set = TaskSet {
515 id: task_set_id,
516 shared_plan: reset_plan_states(self.stage_plan.clone())
517 .map_err(|e| DistError::internal(format!("Failed to reset plan state: {e}")))?,
518 running_partitions: HashMap::new(),
519 dropped_partitions: HashMap::new(),
520 };
521 let shared_plan = new_task_set.shared_plan.clone();
522 self.task_sets.push(new_task_set);
523
524 Ok((task_set_id, shared_plan))
525 }
526
527 pub fn start_task(
528 &mut self,
529 partition: usize,
530 task_set_id: Uuid,
531 abort_handle: AbortHandle,
532 ) -> DistResult<()> {
533 let task_set = self
534 .task_sets
535 .iter_mut()
536 .find(|task_set| task_set.id == task_set_id)
537 .ok_or_else(|| DistError::internal(format!("Task set {task_set_id} not found")))?;
538 task_set.running_partitions.insert(partition, abort_handle);
539 Ok(())
540 }
541
542 pub fn complete_task(&mut self, task_id: TaskId, task_set_id: Uuid, task_metrics: TaskMetrics) {
543 if let Some(task_set) = self
544 .task_sets
545 .iter_mut()
546 .find(|task_set| task_set.id == task_set_id)
547 {
548 task_set
549 .running_partitions
550 .remove(&(task_id.partition as usize));
551 task_set
552 .dropped_partitions
553 .insert(task_id.partition as usize, task_metrics);
554 }
555 }
556
557 pub fn never_executed(&self) -> bool {
558 self.task_sets
559 .iter()
560 .all(|set| set.running_partitions.is_empty() && set.dropped_partitions.is_empty())
561 }
562
563 pub fn abort_running_tasks(&mut self) {
564 for task_set in &mut self.task_sets {
565 task_set.abort_running_partitions();
566 }
567 }
568}
569
570#[derive(Debug)]
571pub struct TaskSet {
572 pub id: Uuid,
573 pub shared_plan: Arc<dyn ExecutionPlan>,
574 pub running_partitions: HashMap<usize, AbortHandle>,
575 pub dropped_partitions: HashMap<usize, TaskMetrics>,
576}
577
578impl TaskSet {
579 pub fn never_executed(&self, partition: &usize) -> bool {
580 !self.running_partitions.contains_key(partition)
581 && !self.dropped_partitions.contains_key(partition)
582 }
583
584 pub fn abort_running_partitions(&mut self) {
585 for (_, abort_handle) in self.running_partitions.drain() {
586 abort_handle.abort();
587 }
588 }
589}
590
591#[derive(Debug, Clone, Serialize, Deserialize)]
592pub struct TaskMetrics {
593 pub output_rows: usize,
594 pub output_bytes: usize,
595 pub completed: bool,
596}
597
598pub struct TaskStream {
599 pub task_id: TaskId,
600 pub task_set_id: Uuid,
601 pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
602 pub event_sender: Sender<Event>,
603 pub stream: SendableRecordBatchStream,
604 pub output_rows: usize,
605 pub output_bytes: usize,
606 pub completed: bool,
607}
608
609impl TaskStream {
610 pub fn new(
611 task_id: TaskId,
612 task_set_id: Uuid,
613 stages: Arc<Mutex<HashMap<StageId, StageState>>>,
614 event_sender: Sender<Event>,
615 stream: SendableRecordBatchStream,
616 ) -> Self {
617 Self {
618 task_id,
619 task_set_id,
620 stages,
621 event_sender,
622 stream,
623 output_rows: 0,
624 output_bytes: 0,
625 completed: false,
626 }
627 }
628}
629
630impl Stream for TaskStream {
631 type Item = Result<RecordBatch, DataFusionError>;
632
633 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
634 match self.stream.as_mut().poll_next(cx) {
635 Poll::Ready(Some(Ok(batch))) => {
636 self.output_rows += batch.num_rows();
637 self.output_bytes += batch.get_array_memory_size();
638 Poll::Ready(Some(Ok(batch)))
639 }
640 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
641 Poll::Ready(None) => {
642 self.completed = true;
643 Poll::Ready(None)
644 }
645 Poll::Pending => Poll::Pending,
646 }
647 }
648}
649
650impl RecordBatchStream for TaskStream {
651 fn schema(&self) -> Arc<Schema> {
652 self.stream.schema()
653 }
654}
655
656impl Drop for TaskStream {
657 fn drop(&mut self) {
658 let task_id = self.task_id.clone();
659 let task_set_id = self.task_set_id;
660 let task_metrics = TaskMetrics {
661 output_bytes: self.output_bytes,
662 output_rows: self.output_rows,
663 completed: self.completed,
664 };
665 debug!("Task {task_id} dropped with metrics: {task_metrics:?}");
666
667 let should_send = {
668 let mut guard = self.stages.lock();
669 if let Some(stage_state) = guard.get_mut(&task_id.stage_id()) {
670 stage_state.complete_task(task_id.clone(), task_set_id, task_metrics);
671 stage_state.stage_id.stage == 0 && stage_state.all_assigned_partitions_completed()
672 } else {
673 false
674 }
675 };
676
677 if should_send {
678 let event = Event::CheckJobCompleted(task_id.job_id.clone());
679 if let Err(e) = self.event_sender.try_send(event) {
680 error!(
681 "Failed to send CheckJobCompleted event after task {task_id} stream dropped: {e}"
682 );
683 }
684 }
685 }
686}
687
688fn start_job_cleaner(stages: Arc<Mutex<HashMap<StageId, StageState>>>, config: Arc<DistConfig>) {
689 tokio::spawn(async move {
690 loop {
691 tokio::time::sleep(config.job_ttl_check_interval).await;
692
693 let mut guard = stages.lock();
694 let mut to_cleanup = Vec::new();
695 for (stage_id, stage_state) in guard.iter() {
696 let age_ms = timestamp_ms() - stage_state.created_at_ms;
697 if age_ms >= config.job_ttl.as_millis() as i64 {
698 to_cleanup.push(stage_id.clone());
699 }
700 }
701
702 if !to_cleanup.is_empty() {
703 debug!(
704 "Stages [{}] lifetime exceed job ttl {}s, cleaning up.",
705 to_cleanup
706 .iter()
707 .map(|id| id.to_string())
708 .collect::<Vec<_>>()
709 .join(", "),
710 config.job_ttl.as_secs()
711 );
712 cleanup_stages(&mut guard, |stage_id| to_cleanup.contains(stage_id));
713 }
714 drop(guard);
715 }
716 });
717}
718
719pub(crate) fn cleanup_stages(
720 stages: &mut HashMap<StageId, StageState>,
721 mut should_cleanup: impl FnMut(&StageId) -> bool,
722) {
723 stages.retain(|stage_id, stage_state| {
724 if should_cleanup(stage_id) {
725 stage_state.abort_running_tasks();
726 false
727 } else {
728 true
729 }
730 });
731}