1use std::{
2 collections::{HashMap, HashSet},
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use uuid::Uuid;
11
12use arrow::array::RecordBatch;
13use datafusion_execution::TaskContext;
14use datafusion_physical_plan::ExecutionPlan;
15
16use futures::{Stream, StreamExt};
17use log::{debug, error};
18use tokio::sync::mpsc::Sender;
19
20use crate::{
21 DistError, DistResult, RecordBatchStream,
22 cluster::{DistCluster, NodeId, NodeStatus},
23 config::DistConfig,
24 event::{Event, EventHandler, cleanup_job, local_stage_stats, start_event_handler},
25 executor::{DefaultExecutor, DistExecutor, logging_executor_metrics},
26 heartbeat::Heartbeater,
27 network::{DistNetwork, ScheduledTasks, StageInfo},
28 planner::{
29 DefaultPlanner, DisplayableStagePlans, DistPlanner, StageId, TaskId,
30 check_initial_stage_plans, resolve_stage_plan,
31 },
32 scheduler::{DefaultScheduler, DisplayableTaskDistribution, DistScheduler},
33 util::{ReceiverStreamBuilder, timestamp_ms},
34};
35
/// Shared runtime of a distributed-execution node: wires the cluster,
/// network, planner, scheduler, and executor together and tracks the
/// stages currently hosted on this node.
///
/// Cloning is cheap — every field is either an `Arc` or a channel sender.
#[derive(Debug, Clone)]
pub struct DistRuntime {
    // Identifier of this node, obtained from the network layer.
    pub node_id: NodeId,
    // Lifecycle status of this node (e.g. Available / Terminating),
    // shared with the heartbeater so status changes are reported.
    pub status: Arc<Mutex<NodeStatus>>,
    // DataFusion task context used when executing plan partitions locally.
    pub task_ctx: Arc<TaskContext>,
    // Runtime configuration (heartbeat interval, job TTL, ...).
    pub config: Arc<DistConfig>,
    // Cluster membership backend (alive-node discovery, heartbeats).
    pub cluster: Arc<dyn DistCluster>,
    // Transport to other nodes: send scheduled tasks, execute remote tasks.
    pub network: Arc<dyn DistNetwork>,
    // Splits a physical plan into per-stage plans.
    pub planner: Arc<dyn DistPlanner>,
    // Assigns stage tasks to cluster nodes.
    pub scheduler: Arc<dyn DistScheduler>,
    // Runs local task driver futures.
    pub executor: Arc<dyn DistExecutor>,
    // Periodic status reporter; started via `DistRuntime::start`.
    pub heartbeater: Arc<Heartbeater>,
    // Stages hosted on this node, keyed by stage id; shared with the
    // heartbeater, event handler, task streams, and the job cleaner.
    pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    // Producer side of the event channel consumed by the event handler.
    pub event_sender: Sender<Event>,
}
51
impl DistRuntime {
    /// Builds a runtime with the default planner/scheduler/executor and
    /// immediately starts the background event-handler loop.
    ///
    /// The heartbeater is constructed here but not started — call
    /// [`DistRuntime::start`] to begin heartbeating and job cleanup.
    pub fn new(
        task_ctx: Arc<TaskContext>,
        config: Arc<DistConfig>,
        cluster: Arc<dyn DistCluster>,
        network: Arc<dyn DistNetwork>,
    ) -> Self {
        let node_id = network.local_node();
        let status = Arc::new(Mutex::new(NodeStatus::Available));
        let stages = Arc::new(Mutex::new(HashMap::new()));
        let heartbeater = Heartbeater {
            node_id: node_id.clone(),
            cluster: cluster.clone(),
            stages: stages.clone(),
            heartbeat_interval: config.heartbeat_interval,
            status: status.clone(),
        };

        // Bounded channel: producers back-pressure at 1024 queued events.
        let (sender, receiver) = tokio::sync::mpsc::channel::<Event>(1024);

        let event_handler = EventHandler {
            local_node: node_id.clone(),
            config: config.clone(),
            cluster: cluster.clone(),
            network: network.clone(),
            local_stages: stages.clone(),
            sender: sender.clone(),
            receiver,
        };
        // Spawns the event loop; it owns `receiver`.
        start_event_handler(event_handler);

        Self {
            node_id: network.local_node(),
            status,
            task_ctx,
            config,
            cluster,
            network,
            planner: Arc::new(DefaultPlanner),
            scheduler: Arc::new(DefaultScheduler::new()),
            executor: Arc::new(DefaultExecutor::new()),
            heartbeater: Arc::new(heartbeater),
            stages,
            event_sender: sender,
        }
    }

    /// Replaces the planner (builder-style).
    pub fn with_planner(self, planner: Arc<dyn DistPlanner>) -> Self {
        Self { planner, ..self }
    }

    /// Replaces the scheduler (builder-style).
    pub fn with_scheduler(self, scheduler: Arc<dyn DistScheduler>) -> Self {
        Self { scheduler, ..self }
    }

    /// Replaces the executor (builder-style).
    pub fn with_executor(self, executor: Arc<dyn DistExecutor>) -> Self {
        Self { executor, ..self }
    }

    /// Starts the background services: the heartbeater and the job-TTL
    /// cleaner. (Currently contains no awaiting work itself.)
    pub async fn start(&self) {
        self.heartbeater.start();
        start_job_cleaner(self.stages.clone(), self.config.clone());
    }

    /// Marks this node as Terminating and pushes one immediate heartbeat so
    /// the cluster learns about the status change right away.
    pub async fn shutdown(&self) {
        *self.status.lock() = NodeStatus::Terminating;
        debug!("Set node status to Terminating, no new tasks will be assigned");

        self.heartbeater.send_heartbeat().await;
    }

    /// Submits a physical plan as a new distributed job.
    ///
    /// Plans the stages, schedules tasks across alive nodes, resolves each
    /// stage plan against the task distribution, and delivers the scheduled
    /// tasks to every involved node (locally via `receive_tasks`, remotely
    /// via the network layer, in parallel).
    ///
    /// Returns the new job id together with the node assignment of the
    /// stage-0 (final/root) tasks, which the caller uses to fetch results.
    ///
    /// # Errors
    /// Fails if planning, validation, scheduling, plan resolution, or task
    /// delivery to any node fails, or if no stage-0 tasks were scheduled.
    pub async fn submit(
        &self,
        plan: Arc<dyn ExecutionPlan>,
    ) -> DistResult<(Uuid, HashMap<TaskId, NodeId>)> {
        let job_id = Uuid::new_v4();
        let mut stage_plans = self.planner.plan_stages(job_id, plan)?;
        debug!(
            "job {job_id} initial stage plans:\n{}",
            DisplayableStagePlans(&stage_plans)
        );
        check_initial_stage_plans(job_id, &stage_plans)?;

        let node_states = self.cluster.alive_nodes().await?;
        debug!(
            "alive nodes: {}",
            node_states
                .keys()
                .map(|n| n.to_string())
                .collect::<Vec<_>>()
                .join(", ")
        );

        let task_distribution = self
            .scheduler
            .schedule(&self.node_id, &node_states, &stage_plans)
            .await?;
        debug!(
            "job {job_id} task distribution: {}",
            DisplayableTaskDistribution(&task_distribution)
        );
        // Stage 0 holds the tasks whose output the submitter consumes; a job
        // without them could never produce a result, so bail out early.
        let stage0_task_distribution: HashMap<TaskId, NodeId> = task_distribution
            .iter()
            .filter(|(task_id, _)| task_id.stage == 0)
            .map(|(task_id, node_id)| (*task_id, node_id.clone()))
            .collect();
        if stage0_task_distribution.is_empty() {
            return Err(DistError::internal(format!(
                "Not found stage0 task distribution in {task_distribution:?} for job {job_id}"
            )));
        }

        // Rewrite each stage plan now that task placement is known.
        for (_, stage_plan) in stage_plans.iter_mut() {
            *stage_plan = resolve_stage_plan(stage_plan.clone(), &task_distribution, self.clone())?;
        }
        debug!(
            "job {job_id} final stage plans:\n{}",
            DisplayableStagePlans(&stage_plans)
        );

        // Invert the distribution: per node, which stages and which tasks.
        let mut node_stages = HashMap::new();
        let mut node_tasks = HashMap::new();
        for (task_id, node_id) in task_distribution.iter() {
            node_stages
                .entry(node_id.clone())
                .or_insert_with(HashSet::new)
                .insert(task_id.stage_id());
            node_tasks
                .entry(node_id.clone())
                .or_insert_with(Vec::new)
                .push(*task_id);
        }

        let mut handles = Vec::with_capacity(node_stages.len());
        for (node_id, stage_ids) in node_stages {
            let node_stage_plans = stage_ids
                .iter()
                .map(|stage_id| {
                    (
                        *stage_id,
                        stage_plans
                            .get(stage_id)
                            .cloned()
                            // Stage ids came from the same plans map via the
                            // scheduler, so a miss would be an internal bug.
                            .expect("stage id should be valid"),
                    )
                })
                .collect::<HashMap<_, _>>();

            let tasks = node_tasks.get(&node_id).cloned().unwrap_or_default();

            let scheduled_tasks =
                ScheduledTasks::new(node_stage_plans, tasks, Arc::new(task_distribution.clone()));

            if node_id == self.node_id {
                // Local delivery: register the stages directly.
                self.receive_tasks(scheduled_tasks).await?;
            } else {
                debug!(
                    "Sending job {job_id} tasks [{}] to {node_id}",
                    scheduled_tasks
                        .task_ids
                        .iter()
                        .map(|t| format!("{}/{}", t.stage, t.partition))
                        .collect::<Vec<String>>()
                        .join(", ")
                );
                // Remote deliveries run concurrently; joined below.
                let network = self.network.clone();
                let handle = tokio::spawn(async move {
                    network.send_tasks(node_id.clone(), scheduled_tasks).await?;
                    Ok::<_, DistError>(())
                });
                handles.push(handle);
            }
        }

        // First `?`: the spawned task panicked/was cancelled; second `?`:
        // the network send itself failed.
        for handle in handles {
            handle.await??;
        }

        logging_executor_metrics(self.executor.handle());

        Ok((job_id, stage0_task_distribution))
    }

    /// Executes one locally-assigned task partition and returns its output
    /// as a stream.
    ///
    /// Looks up the stage, obtains (or creates) a task-set plan for the
    /// partition, then drives `plan.execute(partition, ..)` on the executor,
    /// forwarding batches through a small (capacity 2) channel. The returned
    /// [`TaskStream`] records metrics and reports task completion on drop.
    ///
    /// # Errors
    /// Fails if the stage is unknown on this node or the partition was not
    /// assigned here.
    pub async fn execute_local(&self, task_id: TaskId) -> DistResult<RecordBatchStream> {
        let stage_id = task_id.stage_id();

        let mut guard = self.stages.lock();
        let stage_state = guard
            .get_mut(&stage_id)
            .ok_or_else(|| DistError::internal(format!("Stage {stage_id} not found")))?;
        let (task_set_id, plan) = stage_state.get_plan(task_id.partition as usize)?;
        // Release the stages lock before any await below.
        drop(guard);

        let mut receiver_stream_builder = ReceiverStreamBuilder::new(2);

        let tx = receiver_stream_builder.tx();
        let partition = task_id.partition as usize;
        let task_ctx = self.task_ctx.clone();
        let driver_task = async move {
            let mut df_stream = plan.execute(partition, task_ctx)?;

            while let Some(batch) = df_stream.next().await {
                let batch = batch.map_err(DistError::from);
                if tx.send(batch).await.is_err() {
                    // Receiver dropped — consumer went away; stop quietly.
                    return Ok(());
                }
            }
            Ok(()) as DistResult<()>
        };

        receiver_stream_builder.spawn_on(driver_task, self.executor.handle());

        let stream = receiver_stream_builder.build();

        let task_stream = TaskStream::new(
            task_id,
            task_set_id,
            self.stages.clone(),
            self.event_sender.clone(),
            stream,
        );

        Ok(task_stream.boxed())
    }

    /// Executes a task on another node and returns its record-batch stream.
    ///
    /// # Errors
    /// Fails if `node_id` is actually this node (guards against routing
    /// mistakes) or if the network call fails.
    pub async fn execute_remote(
        &self,
        node_id: NodeId,
        task_id: TaskId,
    ) -> DistResult<RecordBatchStream> {
        if node_id == self.node_id {
            return Err(DistError::internal(format!(
                "remote node id {node_id} is actually self"
            )));
        }

        debug!("Executing remote task {task_id} on node {node_id}");
        self.network.execute_task(node_id, task_id).await
    }

    /// Registers scheduled tasks on this node.
    ///
    /// Converts the payload into per-stage [`StageState`]s, inserts them
    /// into the local stage map, and — if any stage-0 tasks arrived —
    /// notifies the event handler so result consumption can begin.
    pub async fn receive_tasks(&self, scheduled_tasks: ScheduledTasks) -> DistResult<()> {
        debug!(
            "Received job {} tasks: [{}] and plans of stages: [{}]",
            scheduled_tasks.job_id()?,
            scheduled_tasks
                .task_ids
                .iter()
                .map(|t| format!("{}/{}", t.stage, t.partition))
                .collect::<Vec<String>>()
                .join(", "),
            scheduled_tasks
                .stage_plans
                .keys()
                .map(|k| k.stage.to_string())
                .collect::<Vec<String>>()
                .join(", ")
        );

        let stage_states = StageState::from_scheduled_tasks(scheduled_tasks)?;
        let stage_ids = stage_states.keys().cloned().collect::<Vec<StageId>>();
        // Scope the lock so it is released before the await below.
        {
            let mut guard = self.stages.lock();
            guard.extend(stage_states);
            drop(guard);
        }

        let stage0_ids = stage_ids
            .iter()
            .filter(|id| id.stage == 0)
            .cloned()
            .collect::<Vec<StageId>>();
        if !stage0_ids.is_empty() {
            self.event_sender
                .send(Event::ReceivedStage0Tasks(stage0_ids))
                .await
                .map_err(|e| {
                    DistError::internal(format!("Failed to send ReceivedStage0Tasks event: {e}"))
                })?;
        }

        Ok(())
    }

    /// Removes all stage state of `job_id` held on this node, if any.
    pub fn cleanup_local_job(&self, job_id: Uuid) {
        let mut guard = self.stages.lock();
        let stage_ids = guard
            .iter()
            .filter(|(stage_id, _)| stage_id.job_id == job_id)
            .map(|(stage_id, _)| stage_id)
            .collect::<Vec<_>>();
        if !stage_ids.is_empty() {
            debug!(
                "Cleaning up local Job {job_id} stages [{}]",
                stage_ids
                    .iter()
                    .map(|id| id.stage.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            );
            guard.retain(|stage_id, _| stage_id.job_id != job_id);
        }
    }

    /// Cleans up `job_id` cluster-wide, delegating to the event module's
    /// `cleanup_job` (which handles both local and remote stage state).
    pub async fn cleanup_job(&self, job_id: Uuid) -> DistResult<()> {
        cleanup_job(
            &self.node_id,
            &self.cluster,
            &self.network,
            &self.stages,
            job_id,
        )
        .await
    }

    /// Returns this node's stage stats for one job.
    pub fn get_local_job(&self, job_id: Uuid) -> HashMap<StageId, StageInfo> {
        local_stage_stats(&self.stages, Some(job_id))
    }

    /// Returns this node's stage stats for all jobs, grouped by job id.
    pub fn get_local_jobs(&self) -> HashMap<Uuid, HashMap<StageId, StageInfo>> {
        let stage_stat = local_stage_stats(&self.stages, None);

        let mut job_stats: HashMap<Uuid, HashMap<StageId, StageInfo>> = HashMap::new();
        for (stage_id, stage_info) in stage_stat {
            job_stats
                .entry(stage_id.job_id)
                .or_default()
                .insert(stage_id, stage_info);
        }

        job_stats
    }

    /// Returns cluster-wide job stats: local stats merged with the stats
    /// fetched concurrently from every other alive node.
    ///
    /// Stages reported by several nodes are merged by extending assigned
    /// partitions and task-set infos.
    ///
    /// # Errors
    /// Fails if cluster lookup or any remote status fetch fails.
    pub async fn get_all_jobs(&self) -> DistResult<HashMap<Uuid, HashMap<StageId, StageInfo>>> {
        let mut combined_status = local_stage_stats(&self.stages, None);

        let node_states = self.cluster.alive_nodes().await?;

        // Fan out the remote status requests concurrently.
        let mut handles = Vec::new();
        for node_id in node_states.keys() {
            if *node_id != self.node_id {
                let network = self.network.clone();
                let node_id = node_id.clone();
                let handle =
                    tokio::spawn(async move { network.get_job_status(node_id, None).await });
                handles.push(handle);
            }
        }

        for handle in handles {
            let remote_status = handle.await??;
            for (stage_id, remote_stage_info) in remote_status {
                combined_status
                    .entry(stage_id)
                    .and_modify(|existing| {
                        existing
                            .assigned_partitions
                            .extend(&remote_stage_info.assigned_partitions);
                        existing
                            .task_set_infos
                            .extend(remote_stage_info.task_set_infos.clone());
                    })
                    .or_insert(remote_stage_info);
            }
        }

        // Regroup the flat stage map by job id.
        let mut job_stats: HashMap<Uuid, HashMap<StageId, StageInfo>> = HashMap::new();
        for (stage_id, stage_info) in combined_status {
            job_stats
                .entry(stage_id.job_id)
                .or_default()
                .insert(stage_id, stage_info);
        }

        Ok(job_stats)
    }
}
436
/// State of one stage hosted on this node: its (resolved) plan, the
/// partitions assigned here, and the task sets that have executed it.
#[derive(Debug)]
pub struct StageState {
    // Identifier of this stage (job id + stage number).
    pub stage_id: StageId,
    // Creation timestamp (ms); used by the job cleaner for TTL eviction.
    pub create_at_ms: i64,
    // Resolved physical plan for this stage; task sets clone and
    // reset-state it before executing.
    pub stage_plan: Arc<dyn ExecutionPlan>,
    // Partitions of this stage assigned to this node.
    pub assigned_partitions: HashSet<usize>,
    // Executions of this stage's plan; see `StageState::get_plan`.
    pub task_sets: Vec<TaskSet>,
    // Full job-wide task placement, shared across this job's stages.
    pub job_task_distribution: Arc<HashMap<TaskId, NodeId>>,
}
446
impl StageState {
    /// Builds one [`StageState`] per stage referenced by the scheduled
    /// tasks, recording the partitions assigned to this node.
    ///
    /// # Errors
    /// Fails if a task references a stage for which no plan was shipped.
    pub fn from_scheduled_tasks(
        scheduled_tasks: ScheduledTasks,
    ) -> DistResult<HashMap<StageId, StageState>> {
        // Group the incoming task ids by their stage.
        let mut stage_tasks: HashMap<StageId, HashSet<TaskId>> = HashMap::new();
        for task_id in scheduled_tasks.task_ids {
            let stage_id = task_id.stage_id();
            stage_tasks.entry(stage_id).or_default().insert(task_id);
        }

        let mut stage_states = HashMap::new();
        for (stage_id, assigned_task_ids) in stage_tasks {
            let stage_state = StageState {
                stage_id,
                create_at_ms: timestamp_ms(),
                stage_plan: scheduled_tasks
                    .stage_plans
                    .get(&stage_id)
                    .ok_or_else(|| {
                        DistError::internal(format!("Not found plan of stage {stage_id}"))
                    })?
                    .clone(),
                assigned_partitions: assigned_task_ids
                    .iter()
                    .map(|task_id| task_id.partition as usize)
                    .collect(),
                task_sets: Vec::new(),
                job_task_distribution: scheduled_tasks.job_task_distribution.clone(),
            };
            stage_states.insert(stage_id, stage_state);
        }
        Ok(stage_states)
    }

    /// Total number of partitions currently running across all task sets.
    pub fn num_running_tasks(&self) -> usize {
        self.task_sets
            .iter()
            .map(|task_set| task_set.running_partitions.len())
            .sum()
    }

    /// Returns the task-set id and plan to execute `partition` with,
    /// creating a fresh task set (with reset plan state) when needed, and
    /// marks the partition as running in the chosen set.
    ///
    /// # Errors
    /// Fails if `partition` is not assigned to this node.
    pub fn get_plan(&mut self, partition: usize) -> DistResult<(Uuid, Arc<dyn ExecutionPlan>)> {
        if !self.assigned_partitions.contains(&partition) {
            let task_id = self.stage_id.task_id(partition as u32);
            return Err(DistError::internal(format!(
                "Task {task_id} not found in this node"
            )));
        }

        for task_set in self.task_sets.iter_mut() {
            // NOTE(review): this condition reuses a task set in which
            // `partition` has ALREADY run or been dropped, and forces a brand
            // new task set for a partition that never executed. If several
            // partitions are meant to share one plan instance (as the name
            // `shared_plan` suggests), the condition may need to be
            // `task_set.never_executed(&partition)` instead — confirm against
            // the intended retry/sharing semantics.
            if !task_set.never_executed(&partition) {
                task_set.running_partitions.insert(partition);
                return Ok((task_set.id, task_set.shared_plan.clone()));
            }
        }

        // No suitable task set: create one with freshly-reset plan state.
        let task_set_id = Uuid::new_v4();
        let mut new_task_set = TaskSet {
            id: task_set_id,
            shared_plan: self.stage_plan.clone().reset_state()?,
            running_partitions: HashSet::new(),
            dropped_partitions: HashMap::new(),
        };
        new_task_set.running_partitions.insert(partition);
        let shared_plan = new_task_set.shared_plan.clone();
        self.task_sets.push(new_task_set);

        Ok((task_set_id, shared_plan))
    }

    /// Records completion of a task: moves its partition from the running
    /// set to the dropped map (keeping its metrics) in the matching task
    /// set. Unknown `task_set_id`s are ignored.
    pub fn complete_task(&mut self, task_id: TaskId, task_set_id: Uuid, task_metrics: TaskMetrics) {
        if let Some(task_set) = self
            .task_sets
            .iter_mut()
            .find(|task_set| task_set.id == task_set_id)
        {
            task_set
                .running_partitions
                .remove(&(task_id.partition as usize));
            task_set
                .dropped_partitions
                .insert(task_id.partition as usize, task_metrics);
        }
    }

    /// Returns `true` when every partition assigned to this node has been
    /// started (running or dropped) at least once, in any task set.
    pub fn assigned_partitions_executed_at_least_once(&self) -> bool {
        let executed_partitions = self
            .task_sets
            .iter()
            .flat_map(|task_set| {
                let mut executed = task_set.running_partitions.clone();
                executed.extend(task_set.dropped_partitions.keys());
                executed
            })
            .collect::<HashSet<_>>();

        for partition in self.assigned_partitions.iter() {
            if !executed_partitions.contains(partition) {
                return false;
            }
        }
        true
    }

    /// Returns `true` when no partition of this stage has ever started
    /// executing on this node.
    pub fn never_executed(&self) -> bool {
        self.task_sets
            .iter()
            .all(|set| set.running_partitions.is_empty() && set.dropped_partitions.is_empty())
    }
}
557
/// One execution instance of a stage plan: tracks which partitions are
/// running against this plan instance and which have finished (dropped),
/// together with their metrics.
#[derive(Debug)]
pub struct TaskSet {
    // Unique id of this task set; reported back in `complete_task`.
    pub id: Uuid,
    // Plan instance executed by this set (created via `reset_state` from
    // the stage plan in `StageState::get_plan`).
    pub shared_plan: Arc<dyn ExecutionPlan>,
    // Partitions currently streaming output from this plan instance.
    pub running_partitions: HashSet<usize>,
    // Partitions whose stream was dropped, with their final metrics.
    pub dropped_partitions: HashMap<usize, TaskMetrics>,
}
565
566impl TaskSet {
567 pub fn never_executed(&self, partition: &usize) -> bool {
568 !self.running_partitions.contains(partition)
569 && !self.dropped_partitions.contains_key(partition)
570 }
571}
572
/// Metrics collected for one task's output stream, reported when the
/// stream is dropped (see `TaskStream`'s `Drop` impl).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMetrics {
    // Total rows emitted by the task stream.
    pub output_rows: usize,
    // Total in-memory bytes of the emitted batches.
    pub output_bytes: usize,
    // Whether the stream reached its natural end (polled to `None`) rather
    // than being dropped early by the consumer.
    pub completed: bool,
}
579
/// Wrapper around a task's record-batch stream that accumulates output
/// metrics while batches flow through and reports task completion to the
/// stage state / event handler when dropped.
pub struct TaskStream {
    // Identity of the task this stream belongs to.
    pub task_id: TaskId,
    // Task set the task was started in (needed to record completion).
    pub task_set_id: Uuid,
    // Shared stage map, updated on drop via `StageState::complete_task`.
    pub stages: Arc<Mutex<HashMap<StageId, StageState>>>,
    // Event channel used to signal possible job completion on drop.
    pub event_sender: Sender<Event>,
    // The wrapped inner stream of record batches.
    pub stream: RecordBatchStream,
    // Running count of rows emitted so far.
    pub output_rows: usize,
    // Running count of batch memory bytes emitted so far.
    pub output_bytes: usize,
    // Set when the inner stream is exhausted (polled to `None`).
    pub completed: bool,
}
590
impl TaskStream {
    /// Wraps `stream` for the given task, starting with zeroed metrics and
    /// `completed == false`.
    pub fn new(
        task_id: TaskId,
        task_set_id: Uuid,
        stages: Arc<Mutex<HashMap<StageId, StageState>>>,
        event_sender: Sender<Event>,
        stream: RecordBatchStream,
    ) -> Self {
        Self {
            task_id,
            task_set_id,
            stages,
            event_sender,
            stream,
            output_rows: 0,
            output_bytes: 0,
            completed: false,
        }
    }
}
611
612impl Stream for TaskStream {
613 type Item = DistResult<RecordBatch>;
614
615 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
616 match self.stream.as_mut().poll_next(cx) {
617 Poll::Ready(Some(Ok(batch))) => {
618 self.output_rows += batch.num_rows();
619 self.output_bytes += batch.get_array_memory_size();
620 Poll::Ready(Some(Ok(batch)))
621 }
622 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
623 Poll::Ready(None) => {
624 self.completed = true;
625 Poll::Ready(None)
626 }
627 Poll::Pending => Poll::Pending,
628 }
629 }
630}
631
impl Drop for TaskStream {
    /// Reports the task's final metrics to its stage state and, when this
    /// was the last stage-0 partition to execute, asks the event handler to
    /// check whether the whole job completed.
    ///
    /// The work is deferred to a spawned task so `drop` stays non-blocking;
    /// the stages lock is released before the event send is awaited.
    fn drop(&mut self) {
        let task_id = self.task_id;
        let task_set_id = self.task_set_id;
        let task_metrics = TaskMetrics {
            output_bytes: self.output_bytes,
            output_rows: self.output_rows,
            completed: self.completed,
        };
        debug!("Task {task_id} dropped with metrics: {task_metrics:?}");

        let stages = self.stages.clone();
        let sender = self.event_sender.clone();
        tokio::spawn(async move {
            let mut send_event = false;
            // Scope the lock: it must not be held across the await below.
            {
                let mut guard = stages.lock();
                if let Some(stage_state) = guard.get_mut(&task_id.stage_id()) {
                    stage_state.complete_task(task_id, task_set_id, task_metrics);
                    // Only stage 0 finishing all its assigned partitions can
                    // imply the job is done.
                    if stage_state.stage_id.stage == 0
                        && stage_state.assigned_partitions_executed_at_least_once()
                    {
                        send_event = true;
                    }
                }
            }

            if send_event
                && let Err(e) = sender.send(Event::CheckJobCompleted(task_id.job_id)).await
            {
                error!(
                    "Failed to send CheckJobCompleted event after task {task_id} stream dropped: {e}"
                );
            }
        });
    }
}
669
670fn start_job_cleaner(stages: Arc<Mutex<HashMap<StageId, StageState>>>, config: Arc<DistConfig>) {
671 tokio::spawn(async move {
672 loop {
673 tokio::time::sleep(config.job_ttl_check_interval).await;
674
675 let mut guard = stages.lock();
676 let mut to_cleanup = Vec::new();
677 for (stage_id, stage_state) in guard.iter() {
678 let age_ms = timestamp_ms() - stage_state.create_at_ms;
679 if age_ms >= config.job_ttl.as_millis() as i64 {
680 to_cleanup.push(*stage_id);
681 }
682 }
683
684 if !to_cleanup.is_empty() {
685 debug!(
686 "Stages [{}] lifetime exceed job ttl {}s, cleaning up.",
687 to_cleanup
688 .iter()
689 .map(|id| id.to_string())
690 .collect::<Vec<_>>()
691 .join(", "),
692 config.job_ttl.as_secs()
693 );
694 guard.retain(|stage_id, _| !to_cleanup.contains(stage_id));
695 }
696 drop(guard);
697 }
698 });
699}