kapot_scheduler/state/execution_graph.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::{HashMap, HashSet};
19use std::convert::TryInto;
20use std::fmt::{Debug, Formatter};
21use std::iter::FromIterator;
22use std::sync::Arc;
23use std::time::{SystemTime, UNIX_EPOCH};
24
25use datafusion::physical_plan::display::DisplayableExecutionPlan;
26use datafusion::physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor};
27use datafusion::prelude::SessionContext;
28use datafusion_proto::logical_plan::AsLogicalPlan;
29use log::{error, info, warn};
30
31use kapot_core::error::{KapotError, Result};
32use kapot_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec};
33use kapot_core::serde::protobuf::failed_task::FailedReason;
34use kapot_core::serde::protobuf::job_status::Status;
35use kapot_core::serde::protobuf::{
36    self, execution_graph_stage::StageType, FailedTask, JobStatus, ResultLost,
37    RunningJob, SuccessfulJob, TaskStatus,
38};
39use kapot_core::serde::protobuf::{job_status, FailedJob, ShuffleWritePartition};
40use kapot_core::serde::protobuf::{task_status, RunningTask};
41use kapot_core::serde::scheduler::{
42    ExecutorMetadata, PartitionId, PartitionLocation, PartitionStats,
43};
44use kapot_core::serde::KapotCodec;
45use datafusion_proto::physical_plan::AsExecutionPlan;
46
47use crate::display::print_stage_metrics;
48use crate::planner::DistributedPlanner;
49use crate::scheduler_server::event::QueryStageSchedulerEvent;
50use crate::scheduler_server::timestamp_millis;
51use crate::state::execution_graph::execution_stage::RunningStage;
52pub(crate) use crate::state::execution_graph::execution_stage::{
53    ExecutionStage, FailedStage, ResolvedStage, StageOutput, SuccessfulStage, TaskInfo,
54    UnresolvedStage,
55};
56use crate::state::task_manager::UpdatedStages;
57
58mod execution_stage;
59
60/// Represents the DAG for a distributed query plan.
61///
62/// A distributed query plan consists of a set of stages which must be executed sequentially.
63///
64/// Each stage consists of a set of partitions which can be executed in parallel, where each partition
65/// represents a `Task`, which is the basic unit of scheduling in kapot.
66///
67/// As an example, consider a SQL query which performs a simple aggregation:
68///
69/// `SELECT id, SUM(gmv) FROM some_table GROUP BY id`
70///
71/// This will produce a DataFusion execution plan that looks something like
72///
73///
74///   CoalesceBatchesExec: target_batch_size=4096
75///     RepartitionExec: partitioning=Hash([Column { name: "id", index: 0 }], 4)
76///       AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[SUM(some_table.gmv)]
77///         TableScan: some_table
78///
79/// The kapot `DistributedPlanner` will turn this into a distributed plan by creating a shuffle
80/// boundary (called a "Stage") whenever the underlying plan needs to perform a repartition.
81/// In this case we end up with a distributed plan with two stages:
82///
83///
84/// ExecutionGraph[job_id=job, session_id=session, available_tasks=1, complete=false]
85/// =========UnResolvedStage[id=2, children=1]=========
86/// Inputs{1: StageOutput { partition_locations: {}, complete: false }}
87/// ShuffleWriterExec: None
88///   AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[SUM(?table?.gmv)]
89///     CoalesceBatchesExec: target_batch_size=4096
90///       UnresolvedShuffleExec
91/// =========ResolvedStage[id=1, partitions=1]=========
92/// ShuffleWriterExec: Some(Hash([Column { name: "id", index: 0 }], 4))
93///   AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[SUM(?table?.gmv)]
94///     TableScan: some_table
95///
96///
97/// The DAG structure of this `ExecutionGraph` is encoded in the stages. Each stage's `inputs` field
98/// indicates which stages it depends on, and its `output_links` field indicates which
99/// stages it needs to publish its output to.
100///
101/// If a stage's `output_links` is empty then it is the final stage in the query, and it should
102/// publish its outputs to the `ExecutionGraph`'s `output_locations`, representing the final query results.
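///
/// A minimal construction sketch (not a doctest; the identifiers and the `physical_plan`
/// value below are placeholders, not values produced by this module):
///
/// ```ignore
/// // Build the stage DAG for a DataFusion physical plan and inspect what is ready to run.
/// let mut graph = ExecutionGraph::new(
///     "scheduler-1",
///     "job-1",
///     "example job",
///     "session-1",
///     physical_plan, // Arc<dyn ExecutionPlan>
///     queued_at,     // epoch millis when the job was queued
/// )?;
///
/// // Resolved stages are converted to running stages, after which tasks can be handed out.
/// graph.revive();
/// println!(
///     "stages={}, available_tasks={}",
///     graph.stage_count(),
///     graph.available_tasks()
/// );
/// ```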
103#[derive(Clone)]
104pub struct ExecutionGraph {
105    /// Curator scheduler name. Can be `None` if the `ExecutionGraph` is not currently curated by any scheduler
106    scheduler_id: Option<String>,
107    /// ID for this job
108    job_id: String,
109    /// Job name, can be empty string
110    job_name: String,
111    /// Session ID for this job
112    session_id: String,
113    /// Status of this job
114    status: JobStatus,
115    /// Timestamp of when this job was submitted
116    queued_at: u64,
117    /// Job start time
118    start_time: u64,
119    /// Job end time
120    end_time: u64,
121    /// Map from Stage ID -> ExecutionStage
122    stages: HashMap<usize, ExecutionStage>,
123    /// Total number of output partitions
124    output_partitions: usize,
125    /// Locations of this `ExecutionGraph`'s final output
126    output_locations: Vec<PartitionLocation>,
127    /// Task ID generator; generates unique task IDs (TIDs) within the execution graph
128    task_id_gen: usize,
129    /// Failed stage attempts, recorded to limit the number of stage-level retries.
130    /// Map from Stage ID -> Set<stage_attempt_num>
131    failed_stage_attempts: HashMap<usize, HashSet<usize>>,
132}
133
134#[derive(Clone, Debug)]
135pub struct RunningTaskInfo {
136    pub task_id: usize,
137    pub job_id: String,
138    pub stage_id: usize,
139    pub partition_id: usize,
140    pub executor_id: String,
141}
142
143impl ExecutionGraph {
144    pub fn new(
145        scheduler_id: &str,
146        job_id: &str,
147        job_name: &str,
148        session_id: &str,
149        plan: Arc<dyn ExecutionPlan>,
150        queued_at: u64,
151    ) -> Result<Self> {
152        let mut planner = DistributedPlanner::new();
153
154        let output_partitions = plan.properties().output_partitioning().partition_count();
155
156        let shuffle_stages = planner.plan_query_stages(job_id, plan)?;
157
158        let builder = ExecutionStageBuilder::new();
159        let stages = builder.build(shuffle_stages)?;
160
161        let started_at = timestamp_millis();
162
163        Ok(Self {
164            scheduler_id: Some(scheduler_id.to_string()),
165            job_id: job_id.to_string(),
166            job_name: job_name.to_string(),
167            session_id: session_id.to_string(),
168            status: JobStatus {
169                job_id: job_id.to_string(),
170                job_name: job_name.to_string(),
171                status: Some(Status::Running(RunningJob {
172                    queued_at,
173                    started_at,
174                    scheduler: scheduler_id.to_string(),
175                })),
176            },
177            queued_at,
178            start_time: started_at,
179            end_time: 0,
180            stages,
181            output_partitions,
182            output_locations: vec![],
183            task_id_gen: 0,
184            failed_stage_attempts: HashMap::new(),
185        })
186    }
187
188    pub fn job_id(&self) -> &str {
189        self.job_id.as_str()
190    }
191
192    pub fn job_name(&self) -> &str {
193        self.job_name.as_str()
194    }
195
196    pub fn session_id(&self) -> &str {
197        self.session_id.as_str()
198    }
199
200    pub fn status(&self) -> &JobStatus {
201        &self.status
202    }
203
204    pub fn start_time(&self) -> u64 {
205        self.start_time
206    }
207
208    pub fn end_time(&self) -> u64 {
209        self.end_time
210    }
211
212    pub fn stage_count(&self) -> usize {
213        self.stages.len()
214    }
215
216    pub fn next_task_id(&mut self) -> usize {
217        let new_tid = self.task_id_gen;
218        self.task_id_gen += 1;
219        new_tid
220    }
221
222    pub(crate) fn stages(&self) -> &HashMap<usize, ExecutionStage> {
223        &self.stages
224    }
225
226    /// An ExecutionGraph is successful if all its stages are successful
227    pub fn is_successful(&self) -> bool {
228        self.stages
229            .values()
230            .all(|s| matches!(s, ExecutionStage::Successful(_)))
231    }
232
233    pub fn is_complete(&self) -> bool {
234        self.stages
235            .values()
236            .all(|s| matches!(s, ExecutionStage::Successful(_)))
237    }
238
239    /// Revive the execution graph by converting resolved stages to running stages.
240    /// Returns true if any stages were converted, otherwise false.
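    ///
    /// A usage sketch (hypothetical; assumes a mutable `graph` already in scope):
    /// ```ignore
    /// if graph.revive() {
    ///     // Newly running stages may now expose schedulable tasks.
    ///     println!("available tasks: {}", graph.available_tasks());
    /// }
    /// ```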
241    pub fn revive(&mut self) -> bool {
242        let running_stages = self
243            .stages
244            .values()
245            .filter_map(|stage| {
246                if let ExecutionStage::Resolved(resolved_stage) = stage {
247                    Some(resolved_stage.to_running())
248                } else {
249                    None
250                }
251            })
252            .collect::<Vec<_>>();
253
254        if running_stages.is_empty() {
255            false
256        } else {
257            for running_stage in running_stages {
258                self.stages.insert(
259                    running_stage.stage_id,
260                    ExecutionStage::Running(running_stage),
261                );
262            }
263            true
264        }
265    }
266
267    /// Update task statuses and task metrics in the graph.
268    /// This will also push shuffle partitions to their respective shuffle read stages.
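    ///
    /// A sketch of how a scheduler loop might feed updates into the graph (`executor_meta`,
    /// `statuses`, the retry limits, and `handle_event` are placeholders for illustration):
    /// ```ignore
    /// let events = graph.update_task_status(&executor_meta, statuses, 4, 4)?;
    /// for event in events {
    ///     // e.g. QueryStageSchedulerEvent::JobFinished, JobRunningFailed or CancelTasks
    ///     handle_event(event);
    /// }
    /// ```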
269    pub fn update_task_status(
270        &mut self,
271        executor: &ExecutorMetadata,
272        task_statuses: Vec<TaskStatus>,
273        max_task_failures: usize,
274        max_stage_failures: usize,
275    ) -> Result<Vec<QueryStageSchedulerEvent>> {
276        let job_id = self.job_id().to_owned();
277        // First of all, classify the statuses by stages
278        let mut job_task_statuses: HashMap<usize, Vec<TaskStatus>> = HashMap::new();
279        for task_status in task_statuses {
280            let stage_id = task_status.stage_id as usize;
281            let stage_task_statuses = job_task_statuses.entry(stage_id).or_default();
282            stage_task_statuses.push(task_status);
283        }
284
285        // Revive before updating, since some updates may not have been persisted yet.
286        // This will be refined later.
287        self.revive();
288
289        let current_running_stages: HashSet<usize> =
290            HashSet::from_iter(self.running_stages());
291
292        // Copy the failed stage attempts from self
293        let mut failed_stage_attempts: HashMap<usize, HashSet<usize>> = HashMap::new();
294        for (stage_id, attempts) in self.failed_stage_attempts.iter() {
295            failed_stage_attempts
296                .insert(*stage_id, HashSet::from_iter(attempts.iter().copied()));
297        }
298
299        let mut resolved_stages = HashSet::new();
300        let mut successful_stages = HashSet::new();
301        let mut failed_stages = HashMap::new();
302        let mut rollback_running_stages = HashMap::new();
303        let mut resubmit_successful_stages: HashMap<usize, HashSet<usize>> =
304            HashMap::new();
305        let mut reset_running_stages: HashMap<usize, HashSet<usize>> = HashMap::new();
306
307        for (stage_id, stage_task_statuses) in job_task_statuses {
308            if let Some(stage) = self.stages.get_mut(&stage_id) {
309                if let ExecutionStage::Running(running_stage) = stage {
310                    let mut locations = vec![];
311                    for task_status in stage_task_statuses.into_iter() {
312                        let task_stage_attempt_num =
313                            task_status.stage_attempt_num as usize;
314                        if task_stage_attempt_num < running_stage.stage_attempt_num {
315                            warn!("Ignore TaskStatus update with TID {} as it's from Stage {}.{} and there is a more recent stage attempt {}.{} running",
316                                    task_status.task_id, stage_id, task_stage_attempt_num, stage_id, running_stage.stage_attempt_num);
317                            continue;
318                        }
319                        let partition_id = task_status.clone().partition_id as usize;
320                        let task_identity = format!(
321                            "TID {} {}/{}.{}/{}",
322                            task_status.task_id,
323                            job_id,
324                            stage_id,
325                            task_stage_attempt_num,
326                            partition_id
327                        );
328                        let operator_metrics = task_status.metrics.clone();
329
330                        if !running_stage
331                            .update_task_info(partition_id, task_status.clone())
332                        {
333                            continue;
334                        }
335
336                        if let Some(task_status::Status::Failed(failed_task)) =
337                            task_status.status
338                        {
339                            let failed_reason = failed_task.failed_reason;
340
341                            match failed_reason {
342                                Some(FailedReason::FetchPartitionError(
343                                    fetch_partiton_error,
344                                )) => {
345                                    let failed_attempts = failed_stage_attempts
346                                        .entry(stage_id)
347                                        .or_default();
348                                    failed_attempts.insert(task_stage_attempt_num);
349                                    if failed_attempts.len() < max_stage_failures {
350                                        let map_stage_id =
351                                            fetch_partiton_error.map_stage_id as usize;
352                                        let map_partition_id = fetch_partiton_error
353                                            .map_partition_id
354                                            as usize;
355                                        let executor_id =
356                                            fetch_partiton_error.executor_id;
357
358                                        if !failed_stages.is_empty() {
359                                            let error_msg = format!(
360                                                "Stages were already marked failed; ignoring FetchPartitionError from task {task_identity}");
361                                            warn!("{}", error_msg);
362                                        } else {
363                                            // There are different removal strategies here.
364                                            // We could remove only the map_partition_id named in the FetchPartitionError; fewer tasks would
365                                            // need to rerun when the input stage is resubmitted, but this might miss many other bad input partitions and lead to more stage-level retries later.
366                                            // Here we choose to remove all the bad input partitions that match the same executor id in this single input stage.
367                                            // There are more aggressive approaches, such as treating the executor as lost and checking all the running stages in this graph,
368                                            // or counting fetch failures per executor and marking the executor lost globally.
369                                            let removed_map_partitions = running_stage
370                                                .remove_input_partitions(
371                                                    map_stage_id,
372                                                    map_partition_id,
373                                                    &executor_id,
374                                                )?;
375
376                                            let failure_reasons = rollback_running_stages
377                                                .entry(stage_id)
378                                                .or_insert_with(HashSet::new);
379                                            failure_reasons.insert(executor_id);
380
381                                            let missing_inputs =
382                                                resubmit_successful_stages
383                                                    .entry(map_stage_id)
384                                                    .or_default();
385                                            missing_inputs.extend(removed_map_partitions);
386                                            warn!("Need to resubmit the current running Stage {} and its map Stage {} due to FetchPartitionError from task {}",
387                                                    stage_id, map_stage_id, task_identity)
388                                        }
389                                    } else {
390                                        let error_msg = format!(
391                                            "Stage {} has failed {} times, \
392                                            most recent failure reason: {:?}",
393                                            stage_id,
394                                            max_stage_failures,
395                                            failed_task.error
396                                        );
397                                        error!("{}", error_msg);
398                                        failed_stages.insert(stage_id, error_msg);
399                                    }
400                                }
401                                Some(FailedReason::ExecutionError(_)) => {
402                                    failed_stages.insert(stage_id, failed_task.error);
403                                }
404                                Some(_) => {
405                                    if failed_task.retryable
406                                        && failed_task.count_to_failures
407                                    {
408                                        if running_stage.task_failure_number(partition_id)
409                                            < max_task_failures
410                                        {
411                                            // TODO add new struct to track all the failed task infos
412                                            // The failure TaskInfo is ignored and set to None here
413                                            running_stage.reset_task_info(partition_id);
414                                        } else {
415                                            let error_msg = format!(
416                                                "Task {} in Stage {} failed {} times, fail the stage, most recent failure reason: {:?}",
417                                                partition_id, stage_id, max_task_failures, failed_task.error
418                                            );
419                                            error!("{}", error_msg);
420                                            failed_stages.insert(stage_id, error_msg);
421                                        }
422                                    } else if failed_task.retryable {
423                                        // TODO add new struct to track all the failed task infos
424                                        // The failure TaskInfo is ignored and set to None here
425                                        running_stage.reset_task_info(partition_id);
426                                    }
427                                }
428                                None => {
429                                    let error_msg = format!(
430                                        "Task {partition_id} in Stage {stage_id} failed with unknown failure reasons, fail the stage");
431                                    error!("{}", error_msg);
432                                    failed_stages.insert(stage_id, error_msg);
433                                }
434                            }
435                        } else if let Some(task_status::Status::Successful(
436                            successful_task,
437                        )) = task_status.status
438                        {
439                            // update task metrics for the successful task
440                            running_stage
441                                .update_task_metrics(partition_id, operator_metrics)?;
442
443                            locations.append(&mut partition_to_location(
444                                &job_id,
445                                partition_id,
446                                stage_id,
447                                executor,
448                                successful_task.partitions,
449                            ));
450                        } else {
451                            warn!(
452                                "The task {}'s status is invalid for updating",
453                                task_identity
454                            );
455                        }
456                    }
457
458                    let is_final_successful = running_stage.is_successful()
459                        && !reset_running_stages.contains_key(&stage_id);
460                    if is_final_successful {
461                        successful_stages.insert(stage_id);
462                        // if this stage is finally successful, combine the stage metrics into the plan's metric set and print out the plan
463                        if let Some(stage_metrics) = running_stage.stage_metrics.as_ref()
464                        {
465                            print_stage_metrics(
466                                &job_id,
467                                stage_id,
468                                running_stage.plan.as_ref(),
469                                stage_metrics,
470                            );
471                        }
472                    }
473
474                    let output_links = running_stage.output_links.clone();
475                    resolved_stages.extend(
476                        &mut self
477                            .update_stage_output_links(
478                                stage_id,
479                                is_final_successful,
480                                locations,
481                                output_links,
482                            )?
483                            .into_iter(),
484                    );
485                } else if let ExecutionStage::UnResolved(unsolved_stage) = stage {
486                    for task_status in stage_task_statuses.into_iter() {
487                        let task_stage_attempt_num =
488                            task_status.stage_attempt_num as usize;
489                        let partition_id = task_status.clone().partition_id as usize;
490                        let task_identity = format!(
491                            "TID {} {}/{}.{}/{}",
492                            task_status.task_id,
493                            job_id,
494                            stage_id,
495                            task_stage_attempt_num,
496                            partition_id
497                        );
498                        let mut should_ignore = true;
499                        // handle delayed failed tasks if the stage's next attempt is still in UnResolved status.
500                        if let Some(task_status::Status::Failed(failed_task)) =
501                            task_status.status
502                        {
503                            if unsolved_stage.stage_attempt_num - task_stage_attempt_num
504                                == 1
505                            {
506                                let failed_reason = failed_task.failed_reason;
507                                match failed_reason {
508                                    Some(FailedReason::ExecutionError(_)) => {
509                                        should_ignore = false;
510                                        failed_stages.insert(stage_id, failed_task.error);
511                                    }
512                                    Some(FailedReason::FetchPartitionError(
513                                        fetch_partiton_error,
514                                    )) if failed_stages.is_empty()
515                                        && current_running_stages.contains(
516                                            &(fetch_partiton_error.map_stage_id as usize),
517                                        )
518                                        && !unsolved_stage
519                                            .last_attempt_failure_reasons
520                                            .contains(
521                                                &fetch_partiton_error.executor_id,
522                                            ) =>
523                                    {
524                                        should_ignore = false;
525                                        unsolved_stage
526                                            .last_attempt_failure_reasons
527                                            .insert(
528                                                fetch_partiton_error.executor_id.clone(),
529                                            );
530                                        let map_stage_id =
531                                            fetch_partiton_error.map_stage_id as usize;
532                                        let map_partition_id = fetch_partiton_error
533                                            .map_partition_id
534                                            as usize;
535                                        let executor_id =
536                                            fetch_partiton_error.executor_id;
537                                        let removed_map_partitions = unsolved_stage
538                                            .remove_input_partitions(
539                                                map_stage_id,
540                                                map_partition_id,
541                                                &executor_id,
542                                            )?;
543
544                                        let missing_inputs = reset_running_stages
545                                            .entry(map_stage_id)
546                                            .or_default();
547                                        missing_inputs.extend(removed_map_partitions);
548                                        warn!("Need to reset the current running Stage {} due to late come FetchPartitionError from its parent stage {} of task {}",
549                                                    map_stage_id, stage_id, task_identity);
550
551                                        // If earlier task updates had already marked the map stage successful, it needs to be removed.
552                                        if successful_stages.contains(&map_stage_id) {
553                                            successful_stages.remove(&map_stage_id);
554                                        }
555                                        if resolved_stages.contains(&stage_id) {
556                                            resolved_stages.remove(&stage_id);
557                                        }
558                                    }
559                                    _ => {}
560                                }
561                            }
562                        }
563                        if should_ignore {
564                            warn!("Ignoring TaskStatus update of task {} as Stage {}/{} is in UnResolved status", task_identity, job_id, stage_id);
565                        }
566                    }
567                } else {
568                    warn!(
569                        "Stage {}/{} is not in a running state when updating the status of tasks {:?}",
570                        job_id,
571                        stage_id,
572                        stage_task_statuses.into_iter().map(|task_status| task_status.partition_id).collect::<Vec<_>>(),
573                    );
574                }
575            } else {
576                return Err(KapotError::Internal(format!(
577                    "Invalid stage ID {stage_id} for job {job_id}"
578                )));
579            }
580        }
581
582        // Update failed stage attempts back to self
583        for (stage_id, attempts) in failed_stage_attempts.iter() {
584            self.failed_stage_attempts
585                .insert(*stage_id, HashSet::from_iter(attempts.iter().copied()));
586        }
587
588        for (stage_id, missing_parts) in &resubmit_successful_stages {
589            if let Some(stage) = self.stages.get_mut(stage_id) {
590                if let ExecutionStage::Successful(success_stage) = stage {
591                    for partition in missing_parts {
592                        if *partition > success_stage.partitions {
593                            return Err(KapotError::Internal(format!(
594                                "Invalid partition ID {} in map stage {}",
595                                *partition, stage_id
596                            )));
597                        }
598                        let task_info = &mut success_stage.task_infos[*partition];
599                        // Update the task info to failed
600                        task_info.task_status = task_status::Status::Failed(FailedTask {
601                            error: "FetchPartitionError in parent stage".to_owned(),
602                            retryable: true,
603                            count_to_failures: false,
604                            failed_reason: Some(FailedReason::ResultLost(ResultLost {})),
605                        });
606                    }
607                } else {
608                    warn!(
609                        "Stage {}/{} is not in the Successful state when trying to resubmit this stage.",
610                        job_id,
611                        stage_id);
612                }
613            } else {
614                return Err(KapotError::Internal(format!(
615                    "Invalid stage ID {stage_id} for job {job_id}"
616                )));
617            }
618        }
619
620        for (stage_id, missing_parts) in &reset_running_stages {
621            if let Some(stage) = self.stages.get_mut(stage_id) {
622                if let ExecutionStage::Running(running_stage) = stage {
623                    for partition in missing_parts {
624                        if *partition > running_stage.partitions {
625                            return Err(KapotError::Internal(format!(
626                                "Invalid partition ID {} in map stage {}",
627                                *partition, stage_id
628                            )));
629                        }
630                        running_stage.reset_task_info(*partition);
631                    }
632                } else {
633                    warn!(
634                        "Stage {}/{} is not in the Running state when trying to reset the running task.",
635                        job_id,
636                        stage_id);
637                }
638            } else {
639                return Err(KapotError::Internal(format!(
640                    "Invalid stage ID {stage_id} for job {job_id}"
641                )));
642            }
643        }
644
645        self.processing_stages_update(UpdatedStages {
646            resolved_stages,
647            successful_stages,
648            failed_stages,
649            rollback_running_stages,
650            resubmit_successful_stages: resubmit_successful_stages
651                .keys()
652                .cloned()
653                .collect(),
654        })
655    }
656
657    /// Process stage status updates after task status changes
658    fn processing_stages_update(
659        &mut self,
660        updated_stages: UpdatedStages,
661    ) -> Result<Vec<QueryStageSchedulerEvent>> {
662        let job_id = self.job_id().to_owned();
663        let mut has_resolved = false;
664        let mut job_err_msg = "".to_owned();
665
666        for stage_id in updated_stages.resolved_stages {
667            self.resolve_stage(stage_id)?;
668            has_resolved = true;
669        }
670
671        for stage_id in updated_stages.successful_stages {
672            self.succeed_stage(stage_id);
673        }
674
675        // Fail the stage and also abort the job
676        for (stage_id, err_msg) in &updated_stages.failed_stages {
677            job_err_msg =
678                format!("Job failed due to stage {stage_id} failed: {err_msg}\n");
679        }
680
681        let mut events = vec![];
682        // Only handle the rollback logic when there are no failed stages
683        if updated_stages.failed_stages.is_empty() {
684            let mut running_tasks_to_cancel = vec![];
685            for (stage_id, failure_reasons) in updated_stages.rollback_running_stages {
686                let tasks = self.rollback_running_stage(stage_id, failure_reasons)?;
687                running_tasks_to_cancel.extend(tasks);
688            }
689
690            for stage_id in updated_stages.resubmit_successful_stages {
691                self.rerun_successful_stage(stage_id);
692            }
693
694            if !running_tasks_to_cancel.is_empty() {
695                events.push(QueryStageSchedulerEvent::CancelTasks(
696                    running_tasks_to_cancel,
697                ));
698            }
699        }
700
701        if !updated_stages.failed_stages.is_empty() {
702            info!("Job {} has failed", job_id);
703            self.fail_job(job_err_msg.clone());
704            events.push(QueryStageSchedulerEvent::JobRunningFailed {
705                job_id,
706                fail_message: job_err_msg,
707                queued_at: self.queued_at,
708                failed_at: timestamp_millis(),
709            });
710        } else if self.is_successful() {
711            // If this ExecutionGraph is successful, finish it
712            info!("Job {} is successful, finalizing output partitions", job_id);
713            self.succeed_job()?;
714            events.push(QueryStageSchedulerEvent::JobFinished {
715                job_id,
716                queued_at: self.queued_at,
717                completed_at: timestamp_millis(),
718            });
719        } else if has_resolved {
720            events.push(QueryStageSchedulerEvent::JobUpdated(job_id))
721        }
722        Ok(events)
723    }
724
725    /// Return a Vec of resolvable stage ids
726    fn update_stage_output_links(
727        &mut self,
728        stage_id: usize,
729        is_completed: bool,
730        locations: Vec<PartitionLocation>,
731        output_links: Vec<usize>,
732    ) -> Result<Vec<usize>> {
733        let mut resolved_stages = vec![];
734        let job_id = &self.job_id;
735        if output_links.is_empty() {
736            // If `output_links` is empty, then this is a final stage
737            self.output_locations.extend(locations);
738        } else {
739            for link in output_links.iter() {
740                // If this is an intermediate stage, we need to push its `PartitionLocation`s to the parent stage
741                if let Some(linked_stage) = self.stages.get_mut(link) {
742                    if let ExecutionStage::UnResolved(linked_unresolved_stage) =
743                        linked_stage
744                    {
745                        linked_unresolved_stage
746                            .add_input_partitions(stage_id, locations.clone())?;
747
748                        // If all tasks for this stage are complete, mark the input complete in the parent stage
749                        if is_completed {
750                            linked_unresolved_stage.complete_input(stage_id);
751                        }
752
753                        // If all input partitions are ready, we can resolve any UnresolvedShuffleExec in the parent stage plan
754                        if linked_unresolved_stage.resolvable() {
755                            resolved_stages.push(linked_unresolved_stage.stage_id);
756                        }
757                    } else {
758                        return Err(KapotError::Internal(format!(
759                            "Error updating job {job_id}: Stage {link}, as the output link of stage {stage_id}, should be unresolved"
760                        )));
761                    }
762                } else {
763                    return Err(KapotError::Internal(format!(
764                        "Error updating job {job_id}: Invalid output link {link} for stage {stage_id}"
765                    )));
766                }
767            }
768        }
769        Ok(resolved_stages)
770    }
771
772    /// Return all the currently running stage ids
773    pub fn running_stages(&self) -> Vec<usize> {
774        self.stages
775            .iter()
776            .filter_map(|(stage_id, stage)| {
777                if let ExecutionStage::Running(_running) = stage {
778                    Some(*stage_id)
779                } else {
780                    None
781                }
782            })
783            .collect::<Vec<_>>()
784    }
785
786    /// Return all currently running tasks along with the executor ID to which they are assigned
787    pub fn running_tasks(&self) -> Vec<RunningTaskInfo> {
788        self.stages
789            .iter()
790            .flat_map(|(_, stage)| {
791                if let ExecutionStage::Running(stage) = stage {
792                    stage
793                        .running_tasks()
794                        .into_iter()
795                        .map(|(task_id, stage_id, partition_id, executor_id)| {
796                            RunningTaskInfo {
797                                task_id,
798                                job_id: self.job_id.clone(),
799                                stage_id,
800                                partition_id,
801                                executor_id,
802                            }
803                        })
804                        .collect::<Vec<RunningTaskInfo>>()
805                } else {
806                    vec![]
807                }
808            })
809            .collect::<Vec<RunningTaskInfo>>()
810    }
811
812    /// Total number of tasks in this plan that are ready for scheduling
813    pub fn available_tasks(&self) -> usize {
814        self.stages
815            .values()
816            .map(|stage| {
817                if let ExecutionStage::Running(stage) = stage {
818                    stage.available_tasks()
819                } else {
820                    0
821                }
822            })
823            .sum()
824    }
825
826    /// Get next task that can be assigned to the given executor.
827    /// This method should only be called when the resulting task is immediately
828    /// being launched as the status will be set to Running and it will not be
829    /// available to the scheduler.
830    /// If the task is not launched the status must be reset to allow the task to
831    /// be scheduled elsewhere.
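    ///
    /// A usage sketch (hypothetical; the executor id and `launch_on_executor` are placeholders):
    /// ```ignore
    /// if let Some(task) = graph.pop_next_task("executor-1")? {
    ///     // The task is now marked Running; if the launch fails, its status must be
    ///     // reset so it can be scheduled elsewhere.
    ///     launch_on_executor("executor-1", task)?;
    /// }
    /// ```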
832    pub fn pop_next_task(
833        &mut self,
834        executor_id: &str,
835    ) -> Result<Option<TaskDescription>> {
836        if matches!(
837            self.status,
838            JobStatus {
839                status: Some(job_status::Status::Failed(_)),
840                ..
841            }
842        ) {
843            warn!("pop_next_task called on a failed job");
844            return Ok(None);
845        }
846
847        let job_id = self.job_id.clone();
848        let session_id = self.session_id.clone();
849
850        let find_candidate = self.stages.iter().any(|(_stage_id, stage)| {
851            if let ExecutionStage::Running(stage) = stage {
852                stage.available_tasks() > 0
853            } else {
854                false
855            }
856        });
857        let next_task_id = if find_candidate {
858            Some(self.next_task_id())
859        } else {
860            None
861        };
862
863        let mut next_task = self.stages.iter_mut().find(|(_stage_id, stage)| {
864            if let ExecutionStage::Running(stage) = stage {
865                stage.available_tasks() > 0
866            } else {
867                false
868            }
869        }).map(|(stage_id, stage)| {
870            if let ExecutionStage::Running(stage) = stage {
871                let (partition_id, _) = stage
872                    .task_infos
873                    .iter()
874                    .enumerate()
875                    .find(|(_partition, info)| info.is_none())
876                    .ok_or_else(|| {
877                        KapotError::Internal(format!("Error getting next task for job {job_id}: Stage {stage_id} is ready but has no pending tasks"))
878                    })?;
879
880                let partition = PartitionId {
881                    job_id,
882                    stage_id: *stage_id,
883                    partition_id,
884                };
885
886                let task_id = next_task_id.unwrap();
887                let task_attempt = stage.task_failure_numbers[partition_id];
888                let task_info = TaskInfo {
889                    task_id,
890                    scheduled_time: SystemTime::now()
891                        .duration_since(UNIX_EPOCH)
892                        .unwrap()
893                        .as_millis(),
894                    // These times will be updated when the task finishes
895                    launch_time: 0,
896                    start_exec_time: 0,
897                    end_exec_time: 0,
898                    finish_time: 0,
899                    task_status: task_status::Status::Running(RunningTask {
900                        executor_id: executor_id.to_owned()
901                    }),
902                };
903
904                // Set the task info to Running for new task
905                stage.task_infos[partition_id] = Some(task_info);
906
907                Ok(TaskDescription {
908                    session_id,
909                    partition,
910                    stage_attempt_num: stage.stage_attempt_num,
911                    task_id,
912                    task_attempt,
913                    data_cache: false,
914                    plan: stage.plan.clone(),
915                })
916            } else {
917                Err(KapotError::General(format!("Stage {stage_id} is not a running stage")))
918            }
919        }).transpose()?;
920
921        // If no available tasks were found in the running stages,
922        // try to find a resolved stage and convert it to a running stage
923        if next_task.is_none() {
924            if self.revive() {
925                next_task = self.pop_next_task(executor_id)?;
926            } else {
927                next_task = None;
928            }
929        }
930
931        Ok(next_task)
932    }
933
934    pub(crate) fn fetch_running_stage(
935        &mut self,
936        black_list: &[usize],
937    ) -> Option<(&mut RunningStage, &mut usize)> {
938        if matches!(
939            self.status,
940            JobStatus {
941                status: Some(job_status::Status::Failed(_)),
942                ..
943            }
944        ) {
945            warn!("fetch_running_stage called on a failed job");
946            return None;
947        }
948
949        let running_stage_id = self.get_running_stage_id(black_list);
950        if let Some(running_stage_id) = running_stage_id {
951            if let Some(ExecutionStage::Running(running_stage)) =
952                self.stages.get_mut(&running_stage_id)
953            {
954                Some((running_stage, &mut self.task_id_gen))
955            } else {
956                warn!("Failed to find a running stage with id {running_stage_id}");
957                None
958            }
959        } else {
960            None
961        }
962    }
963
964    fn get_running_stage_id(&mut self, black_list: &[usize]) -> Option<usize> {
965        let mut running_stage_id = self.stages.iter().find_map(|(stage_id, stage)| {
966            if black_list.contains(stage_id) {
967                None
968            } else if let ExecutionStage::Running(stage) = stage {
969                if stage.available_tasks() > 0 {
970                    Some(*stage_id)
971                } else {
972                    None
973                }
974            } else {
975                None
976            }
977        });
978
979        // If no available tasks were found in the running stages,
980        // try to find a resolved stage and convert it to a running stage
981        if running_stage_id.is_none() {
982            if self.revive() {
983                running_stage_id = self.get_running_stage_id(black_list);
984            } else {
985                running_stage_id = None;
986            }
987        }
988
989        running_stage_id
990    }
991
992    pub fn update_status(&mut self, status: JobStatus) {
993        self.status = status;
994    }
995
996    pub fn output_locations(&self) -> Vec<PartitionLocation> {
997        self.output_locations.clone()
998    }
999
1000    /// Reset running and successful stages on a given executor.
1001    /// This will first check the unresolved/resolved/running stages and reset their running and successful tasks.
1002    /// Then it will check each successful stage for running parent stages that still need to read shuffle data from it.
1003    /// If there are any, it resets the successful tasks and rolls back the resolved shuffles recursively.
1004    ///
1005    /// Returns the reset stage ids and the running tasks that should be killed.
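    ///
    /// A sketch of the intended use when an executor is lost (the executor id and
    /// `cancel_tasks` are placeholders):
    /// ```ignore
    /// let (reset_stages, tasks_to_cancel) = graph.reset_stages_on_lost_executor("executor-1")?;
    /// info!("Reset {} stages on lost executor", reset_stages.len());
    /// cancel_tasks(tasks_to_cancel)?;
    /// ```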
1006    pub fn reset_stages_on_lost_executor(
1007        &mut self,
1008        executor_id: &str,
1009    ) -> Result<(HashSet<usize>, Vec<RunningTaskInfo>)> {
1010        let mut reset = HashSet::new();
1011        let mut tasks_to_cancel = vec![];
1012        loop {
1013            let reset_stage = self.reset_stages_internal(executor_id)?;
1014            if !reset_stage.0.is_empty() {
1015                reset.extend(reset_stage.0.iter());
1016                tasks_to_cancel.extend(reset_stage.1)
1017            } else {
1018                return Ok((reset, tasks_to_cancel));
1019            }
1020        }
1021    }
1022
1023    fn reset_stages_internal(
1024        &mut self,
1025        executor_id: &str,
1026    ) -> Result<(HashSet<usize>, Vec<RunningTaskInfo>)> {
1027        let job_id = self.job_id.clone();
1028        // collect the input stages that need to be resubmitted
1029        let mut resubmit_inputs: HashSet<usize> = HashSet::new();
1030
1031        let mut reset_running_stage = HashSet::new();
1032        let mut rollback_resolved_stages = HashSet::new();
1033        let mut rollback_running_stages = HashSet::new();
1034        let mut resubmit_successful_stages = HashSet::new();
1035
1036        let mut empty_inputs: HashMap<usize, StageOutput> = HashMap::new();
1037        // check the unresolved, resolved and running stages
1038        self.stages
1039            .iter_mut()
1040            .for_each(|(stage_id, stage)| {
1041                let stage_inputs = match stage {
1042                    ExecutionStage::UnResolved(stage) => {
1043                        &mut stage.inputs
1044                    }
1045                    ExecutionStage::Resolved(stage) => {
1046                        &mut stage.inputs
1047                    }
1048                    ExecutionStage::Running(stage) => {
1049                        let reset = stage.reset_tasks(executor_id);
1050                        if reset > 0 {
1051                            warn!(
1052                        "Reset {} tasks for running job/stage {}/{} on lost Executor {}",
1053                        reset, job_id, stage_id, executor_id
1054                        );
1055                            reset_running_stage.insert(*stage_id);
1056                        }
1057                        &mut stage.inputs
1058                    }
1059                    _ => &mut empty_inputs
1060                };
1061
1062                // For each stage input, check whether any input locations match that executor,
1063                // and determine the input stages to resubmit if those input stages are successful.
1064                let mut rollback_stage = false;
1065                stage_inputs.iter_mut().for_each(|(input_stage_id, stage_output)| {
1066                    let mut match_found = false;
1067                    stage_output.partition_locations.iter_mut().for_each(
1068                        |(_partition, locs)| {
1069                            let before_len = locs.len();
1070                            locs.retain(|loc| loc.executor_meta.id != executor_id);
1071                            if locs.len() < before_len {
1072                                match_found = true;
1073                            }
1074                        },
1075                    );
1076                    if match_found {
1077                        stage_output.complete = false;
1078                        rollback_stage = true;
1079                        resubmit_inputs.insert(*input_stage_id);
1080                    }
1081                });
1082
1083                if rollback_stage {
1084                    match stage {
1085                        ExecutionStage::Resolved(_) => {
1086                            rollback_resolved_stages.insert(*stage_id);
1087                            warn!(
1088                            "Roll back resolved job/stage {}/{} and change ShuffleReaderExec back to UnresolvedShuffleExec",
1089                            job_id, stage_id);
1090                        }
1091                        ExecutionStage::Running(_) => {
1092                            rollback_running_stages.insert(*stage_id);
1093                            warn!(
1094                            "Roll back running job/stage {}/{} and change ShuffleReaderExec back to UnresolvedShuffleExec",
1095                            job_id, stage_id);
1096                        }
1097                        _ => {}
1098                    }
1099                }
1100            });
1101
1102        // check and reset the successful stages
1103        if !resubmit_inputs.is_empty() {
1104            self.stages
1105                .iter_mut()
1106                .filter(|(stage_id, _stage)| resubmit_inputs.contains(stage_id))
1107                .filter_map(|(_stage_id, stage)| {
1108                    if let ExecutionStage::Successful(success) = stage {
1109                        Some(success)
1110                    } else {
1111                        None
1112                    }
1113                })
1114                .for_each(|stage| {
1115                    let reset = stage.reset_tasks(executor_id);
1116                    if reset > 0 {
1117                        resubmit_successful_stages.insert(stage.stage_id);
1118                        warn!(
1119                            "Reset {} tasks for successful job/stage {}/{} on lost Executor {}",
1120                            reset, job_id, stage.stage_id, executor_id
1121                        )
1122                    }
1123                });
1124        }
1125
1126        for stage_id in rollback_resolved_stages.iter() {
1127            self.rollback_resolved_stage(*stage_id)?;
1128        }
1129
1130        let mut all_running_tasks = vec![];
1131        for stage_id in rollback_running_stages.iter() {
1132            let tasks = self.rollback_running_stage(
1133                *stage_id,
1134                HashSet::from([executor_id.to_owned()]),
1135            )?;
1136            all_running_tasks.extend(tasks);
1137        }
1138
1139        for stage_id in resubmit_successful_stages.iter() {
1140            self.rerun_successful_stage(*stage_id);
1141        }
1142
1143        let mut reset_stage = HashSet::new();
1144        reset_stage.extend(reset_running_stage);
1145        reset_stage.extend(rollback_resolved_stages);
1146        reset_stage.extend(rollback_running_stages);
1147        reset_stage.extend(resubmit_successful_stages);
1148        Ok((reset_stage, all_running_tasks))
1149    }
1150
1151    /// Convert unresolved stage to be resolved
1152    pub fn resolve_stage(&mut self, stage_id: usize) -> Result<bool> {
1153        if let Some(ExecutionStage::UnResolved(stage)) = self.stages.remove(&stage_id) {
1154            self.stages
1155                .insert(stage_id, ExecutionStage::Resolved(stage.to_resolved()?));
1156            Ok(true)
1157        } else {
1158            warn!(
1159                "Failed to find an unresolved stage {}/{} to resolve",
1160                self.job_id(),
1161                stage_id
1162            );
1163            Ok(false)
1164        }
1165    }
1166
1167    /// Convert running stage to be successful
1168    pub fn succeed_stage(&mut self, stage_id: usize) -> bool {
1169        if let Some(ExecutionStage::Running(stage)) = self.stages.remove(&stage_id) {
1170            self.stages
1171                .insert(stage_id, ExecutionStage::Successful(stage.to_successful()));
1172            self.clear_stage_failure(stage_id);
1173            true
1174        } else {
1175            warn!(
1176                "Failed to find a running stage {}/{} to mark as successful",
1177                self.job_id(),
1178                stage_id
1179            );
1180            false
1181        }
1182    }
1183
1184    /// Convert running stage to be failed
1185    pub fn fail_stage(&mut self, stage_id: usize, err_msg: String) -> bool {
1186        if let Some(ExecutionStage::Running(stage)) = self.stages.remove(&stage_id) {
1187            self.stages
1188                .insert(stage_id, ExecutionStage::Failed(stage.to_failed(err_msg)));
1189            true
1190        } else {
1191            info!(
1192                "Failed to find a running stage {}/{} to fail",
1193                self.job_id(),
1194                stage_id
1195            );
1196            false
1197        }
1198    }
1199
1200    /// Convert a running stage back to unresolved.
1201    /// Returns a Vec of RunningTaskInfo for the running tasks in this stage.
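    ///
    /// A sketch of the intended call pattern (the stage id and failure-reason set are illustrative):
    /// ```ignore
    /// let reasons = HashSet::from(["executor-1".to_string()]);
    /// let tasks_to_cancel = graph.rollback_running_stage(2, reasons)?;
    /// ```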
1202    pub fn rollback_running_stage(
1203        &mut self,
1204        stage_id: usize,
1205        failure_reasons: HashSet<String>,
1206    ) -> Result<Vec<RunningTaskInfo>> {
1207        if let Some(ExecutionStage::Running(stage)) = self.stages.remove(&stage_id) {
1208            let running_tasks = stage
1209                .running_tasks()
1210                .into_iter()
1211                .map(
1212                    |(task_id, stage_id, partition_id, executor_id)| RunningTaskInfo {
1213                        task_id,
1214                        job_id: self.job_id.clone(),
1215                        stage_id,
1216                        partition_id,
1217                        executor_id,
1218                    },
1219                )
1220                .collect();
1221            self.stages.insert(
1222                stage_id,
1223                ExecutionStage::UnResolved(stage.to_unresolved(failure_reasons)?),
1224            );
1225            Ok(running_tasks)
1226        } else {
1227            warn!(
1228                "Failed to find a running stage {}/{} to roll back",
1229                self.job_id(),
1230                stage_id
1231            );
1232            Ok(vec![])
1233        }
1234    }
1235
1236    /// Convert a resolved stage back to unresolved
1237    pub fn rollback_resolved_stage(&mut self, stage_id: usize) -> Result<bool> {
1238        if let Some(ExecutionStage::Resolved(stage)) = self.stages.remove(&stage_id) {
1239            self.stages
1240                .insert(stage_id, ExecutionStage::UnResolved(stage.to_unresolved()?));
1241            Ok(true)
1242        } else {
1243            warn!(
1244                "Failed to find a resolved stage {}/{} to roll back",
1245                self.job_id(),
1246                stage_id
1247            );
1248            Ok(false)
1249        }
1250    }
1251
1252    /// Convert a successful stage back to running
1253    pub fn rerun_successful_stage(&mut self, stage_id: usize) -> bool {
1254        if let Some(ExecutionStage::Successful(stage)) = self.stages.remove(&stage_id) {
1255            self.stages
1256                .insert(stage_id, ExecutionStage::Running(stage.to_running()));
1257            true
1258        } else {
1259            warn!(
1260                "Failed to find a successful stage {}/{} to rerun",
1261                self.job_id(),
1262                stage_id
1263            );
1264            false
1265        }
1266    }
1267
1268    /// Fail the job with an error message
1269    pub fn fail_job(&mut self, error: String) {
1270        self.status = JobStatus {
1271            job_id: self.job_id.clone(),
1272            job_name: self.job_name.clone(),
1273            status: Some(Status::Failed(FailedJob {
1274                error,
1275                queued_at: self.queued_at,
1276                started_at: self.start_time,
1277                ended_at: self.end_time,
1278            })),
1279        };
1280    }
1281
1282    /// Mark the job as successful
1283    pub fn succeed_job(&mut self) -> Result<()> {
1284        if !self.is_successful() {
1285            return Err(KapotError::Internal(format!(
1286                "Attempt to finalize an incomplete job {}",
1287                self.job_id()
1288            )));
1289        }
1290
1291        let partition_location = self
1292            .output_locations()
1293            .into_iter()
1294            .map(|l| l.try_into())
1295            .collect::<Result<Vec<_>>>()?;
1296
1297        self.end_time = SystemTime::now()
1298            .duration_since(UNIX_EPOCH)
1299            .unwrap()
1300            .as_millis() as u64;
1301
1302        self.status = JobStatus {
1303            job_id: self.job_id.clone(),
1304            job_name: self.job_name.clone(),
1305            status: Some(job_status::Status::Successful(SuccessfulJob {
1306                partition_location,
1307                queued_at: self.queued_at,
1308                started_at: self.start_time,
1309                ended_at: self.end_time,
1310            })),
1311        };
1312
1313        Ok(())
1314    }
1315
1316    /// Clear the failed attempt count for this stage once the stage finally succeeds
1317    fn clear_stage_failure(&mut self, stage_id: usize) {
1318        self.failed_stage_attempts.remove(&stage_id);
1319    }
1320
1321    pub(crate) async fn decode_execution_graph<
1322        T: 'static + AsLogicalPlan,
1323        U: 'static + AsExecutionPlan,
1324    >(
1325        proto: protobuf::ExecutionGraph,
1326        codec: &KapotCodec<T, U>,
1327        session_ctx: &SessionContext,
1328    ) -> Result<ExecutionGraph> {
1329        let mut stages: HashMap<usize, ExecutionStage> = HashMap::new();
1330        for graph_stage in proto.stages {
1331            let stage_type = graph_stage.stage_type.expect("Unexpected empty stage");
1332
1333            let execution_stage = match stage_type {
1334                StageType::UnresolvedStage(stage) => {
1335                    let stage: UnresolvedStage =
1336                        UnresolvedStage::decode(stage, codec, session_ctx)?;
1337                    (stage.stage_id, ExecutionStage::UnResolved(stage))
1338                }
1339                StageType::ResolvedStage(stage) => {
1340                    let stage: ResolvedStage =
1341                        ResolvedStage::decode(stage, codec, session_ctx)?;
1342                    (stage.stage_id, ExecutionStage::Resolved(stage))
1343                }
1344                StageType::SuccessfulStage(stage) => {
1345                    let stage: SuccessfulStage =
1346                        SuccessfulStage::decode(stage, codec, session_ctx)?;
1347                    (stage.stage_id, ExecutionStage::Successful(stage))
1348                }
1349                StageType::FailedStage(stage) => {
1350                    let stage: FailedStage =
1351                        FailedStage::decode(stage, codec, session_ctx)?;
1352                    (stage.stage_id, ExecutionStage::Failed(stage))
1353                }
1354            };
1355
1356            stages.insert(execution_stage.0, execution_stage.1);
1357        }
1358
1359        let output_locations: Vec<PartitionLocation> = proto
1360            .output_locations
1361            .into_iter()
1362            .map(|loc| loc.try_into())
1363            .collect::<Result<Vec<_>>>()?;
1364
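        // Rebuild the per-stage set of failed attempt numbers from its proto representation.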
1365        let failed_stage_attempts = proto
1366            .failed_attempts
1367            .into_iter()
1368            .map(|attempt| {
1369                (
1370                    attempt.stage_id as usize,
1371                    HashSet::from_iter(
1372                        attempt
1373                            .stage_attempt_num
1374                            .into_iter()
1375                            .map(|num| num as usize),
1376                    ),
1377                )
1378            })
1379            .collect();
1380
1381        Ok(ExecutionGraph {
1382            scheduler_id: (!proto.scheduler_id.is_empty()).then_some(proto.scheduler_id),
1383            job_id: proto.job_id,
1384            job_name: proto.job_name,
1385            session_id: proto.session_id,
1386            status: proto.status.ok_or_else(|| {
1387                KapotError::Internal(
1388                    "Invalid Execution Graph: missing job status".to_owned(),
1389                )
1390            })?,
1391            queued_at: proto.queued_at,
1392            start_time: proto.start_time,
1393            end_time: proto.end_time,
1394            stages,
1395            output_partitions: proto.output_partitions as usize,
1396            output_locations,
1397            task_id_gen: proto.task_id_gen as usize,
1398            failed_stage_attempts,
1399        })
1400    }
1401
1402    /// Running stages are not persisted, so they are not encoded directly.
1403    /// Instead, a running stage is converted back to a resolved stage before being encoded and persisted.
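    ///
    /// A minimal round-trip sketch (illustrative only; `graph`, `codec` and `session_ctx`
    /// are assumed to already be in scope):
    ///
    /// ```ignore
    /// let proto = ExecutionGraph::encode_execution_graph(graph, &codec)?;
    /// let graph =
    ///     ExecutionGraph::decode_execution_graph(proto, &codec, &session_ctx).await?;
    /// ```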
1404    pub(crate) fn encode_execution_graph<
1405        T: 'static + AsLogicalPlan,
1406        U: 'static + AsExecutionPlan,
1407    >(
1408        graph: ExecutionGraph,
1409        codec: &KapotCodec<T, U>,
1410    ) -> Result<protobuf::ExecutionGraph> {
1411        let job_id = graph.job_id().to_owned();
1412
1413        let stages = graph
1414            .stages
1415            .into_values()
1416            .map(|stage| {
1417                let stage_type = match stage {
1418                    ExecutionStage::UnResolved(stage) => {
1419                        StageType::UnresolvedStage(UnresolvedStage::encode(stage, codec)?)
1420                    }
1421                    ExecutionStage::Resolved(stage) => {
1422                        StageType::ResolvedStage(ResolvedStage::encode(stage, codec)?)
1423                    }
1424                    ExecutionStage::Running(stage) => StageType::ResolvedStage(
1425                        ResolvedStage::encode(stage.to_resolved(), codec)?,
1426                    ),
1427                    ExecutionStage::Successful(stage) => StageType::SuccessfulStage(
1428                        SuccessfulStage::encode(job_id.clone(), stage, codec)?,
1429                    ),
1430                    ExecutionStage::Failed(stage) => StageType::FailedStage(
1431                        FailedStage::encode(job_id.clone(), stage, codec)?,
1432                    ),
1433                };
1434                Ok(protobuf::ExecutionGraphStage {
1435                    stage_type: Some(stage_type),
1436                })
1437            })
1438            .collect::<Result<Vec<_>>>()?;
1439
1440        let output_locations: Vec<protobuf::PartitionLocation> = graph
1441            .output_locations
1442            .into_iter()
1443            .map(|loc| loc.try_into())
1444            .collect::<Result<Vec<_>>>()?;
1445
1446        let failed_attempts: Vec<protobuf::StageAttempts> = graph
1447            .failed_stage_attempts
1448            .into_iter()
1449            .map(|(stage_id, attempts)| {
1450                let stage_attempt_num = attempts
1451                    .into_iter()
1452                    .map(|num| num as u32)
1453                    .collect::<Vec<_>>();
1454                protobuf::StageAttempts {
1455                    stage_id: stage_id as u32,
1456                    stage_attempt_num,
1457                }
1458            })
1459            .collect::<Vec<_>>();
1460
1461        Ok(protobuf::ExecutionGraph {
1462            job_id: graph.job_id,
1463            job_name: graph.job_name,
1464            session_id: graph.session_id,
1465            status: Some(graph.status),
1466            queued_at: graph.queued_at,
1467            start_time: graph.start_time,
1468            end_time: graph.end_time,
1469            stages,
1470            output_partitions: graph.output_partitions as u64,
1471            output_locations,
1472            scheduler_id: graph.scheduler_id.unwrap_or_default(),
1473            task_id_gen: graph.task_id_gen as u32,
1474            failed_attempts,
1475        })
1476    }
1477}
1478
1479impl Debug for ExecutionGraph {
1480    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1481        let stages = self
1482            .stages
1483            .values()
1484            .map(|stage| format!("{stage:?}"))
1485            .collect::<Vec<String>>()
1486            .join("");
1487        write!(f, "ExecutionGraph[job_id={}, session_id={}, available_tasks={}, is_successful={}]\n{}",
1488               self.job_id, self.session_id, self.available_tasks(), self.is_successful(), stages)
1489    }
1490}
1491
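/// Build the initial `TaskInfo` for a task that has just been assigned to `executor_id`,
/// stamping the current time as the scheduled time.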
1492pub(crate) fn create_task_info(executor_id: String, task_id: usize) -> TaskInfo {
1493    TaskInfo {
1494        task_id,
1495        scheduled_time: SystemTime::now()
1496            .duration_since(UNIX_EPOCH)
1497            .unwrap()
1498            .as_millis(),
1499        // These times will be updated when the task finishes
1500        launch_time: 0,
1501        start_exec_time: 0,
1502        end_exec_time: 0,
1503        finish_time: 0,
1504        task_status: task_status::Status::Running(RunningTask { executor_id }),
1505    }
1506}
1507
1508/// Utility for building a set of `ExecutionStage`s from
1509/// a list of `ShuffleWriterExec`.
1510///
1511/// This will infer the dependency structure for the stages
1512/// so that we can construct a DAG from the stages.
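///
/// For example (illustrative only): given three shuffle stages where stages 1 and 2 both
/// feed stage 3, the builder ends up with `output_links = {1: [3], 2: [3]}` and
/// `stage_dependencies = {3: [1, 2]}`, so stage 3 starts out `UnResolved` while the leaf
/// stages 1 and 2 start out `Resolved`.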
1513struct ExecutionStageBuilder {
1514    /// Stage ID which is currently being visited
1515    current_stage_id: usize,
1516    /// Map from stage ID -> List of child stage IDs
1517    stage_dependencies: HashMap<usize, Vec<usize>>,
1518    /// Map from Stage ID -> output link
1519    output_links: HashMap<usize, Vec<usize>>,
1520}
1521
1522impl ExecutionStageBuilder {
1523    pub fn new() -> Self {
1524        Self {
1525            current_stage_id: 0,
1526            stage_dependencies: HashMap::new(),
1527            output_links: HashMap::new(),
1528        }
1529    }
1530
1531    pub fn build(
1532        mut self,
1533        stages: Vec<Arc<ShuffleWriterExec>>,
1534    ) -> Result<HashMap<usize, ExecutionStage>> {
1535        let mut execution_stages: HashMap<usize, ExecutionStage> = HashMap::new();
1536        // First, build the dependency graph
1537        for stage in &stages {
1538            accept(stage.as_ref(), &mut self)?;
1539        }
1540
1541        // Now, create the execution stages
1542        for stage in stages {
1543            let stage_id = stage.stage_id();
1544            let output_links = self.output_links.remove(&stage_id).unwrap_or_default();
1545
1546            let child_stages = self
1547                .stage_dependencies
1548                .remove(&stage_id)
1549                .unwrap_or_default();
1550
1551            let stage = if child_stages.is_empty() {
1552                ExecutionStage::Resolved(ResolvedStage::new(
1553                    stage_id,
1554                    0,
1555                    stage,
1556                    output_links,
1557                    HashMap::new(),
1558                    HashSet::new(),
1559                ))
1560            } else {
1561                ExecutionStage::UnResolved(UnresolvedStage::new(
1562                    stage_id,
1563                    stage,
1564                    output_links,
1565                    child_stages,
1566                ))
1567            };
1568            execution_stages.insert(stage_id, stage);
1569        }
1570
1571        Ok(execution_stages)
1572    }
1573}
1574
1575impl ExecutionPlanVisitor for ExecutionStageBuilder {
1576    type Error = KapotError;
1577
1578    fn pre_visit(
1579        &mut self,
1580        plan: &dyn ExecutionPlan,
1581    ) -> std::result::Result<bool, Self::Error> {
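        // A ShuffleWriterExec is the root of the stage currently being visited, while an
        // UnresolvedShuffleExec reads another stage's shuffle output, so each one seen here
        // adds an edge to both `output_links` and `stage_dependencies`.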
1582        if let Some(shuffle_write) = plan.as_any().downcast_ref::<ShuffleWriterExec>() {
1583            self.current_stage_id = shuffle_write.stage_id();
1584        } else if let Some(unresolved_shuffle) =
1585            plan.as_any().downcast_ref::<UnresolvedShuffleExec>()
1586        {
1587            if let Some(output_links) =
1588                self.output_links.get_mut(&unresolved_shuffle.stage_id)
1589            {
1590                if !output_links.contains(&self.current_stage_id) {
1591                    output_links.push(self.current_stage_id);
1592                }
1593            } else {
1594                self.output_links
1595                    .insert(unresolved_shuffle.stage_id, vec![self.current_stage_id]);
1596            }
1597
1598            if let Some(deps) = self.stage_dependencies.get_mut(&self.current_stage_id) {
1599                if !deps.contains(&unresolved_shuffle.stage_id) {
1600                    deps.push(unresolved_shuffle.stage_id);
1601                }
1602            } else {
1603                self.stage_dependencies
1604                    .insert(self.current_stage_id, vec![unresolved_shuffle.stage_id]);
1605            }
1606        }
1607        Ok(true)
1608    }
1609}
1610
1611/// Represents the basic unit of work for the kapot executor. Will execute
1612/// one partition of one stage on one task slot.
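/// It carries the stage/partition identifiers, the attempt counters, and the physical plan
/// fragment the executor should run.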
1613#[derive(Clone)]
1614pub struct TaskDescription {
1615    pub session_id: String,
1616    pub partition: PartitionId,
1617    pub stage_attempt_num: usize,
1618    pub task_id: usize,
1619    pub task_attempt: usize,
1620    pub data_cache: bool,
1621    pub plan: Arc<dyn ExecutionPlan>,
1622}
1623
1624impl Debug for TaskDescription {
1625    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1626        let plan = DisplayableExecutionPlan::new(self.plan.as_ref()).indent(false);
1627        write!(
1628            f,
1629            "TaskDescription[session_id: {},job: {}, stage: {}.{}, partition: {} task_id {}, task attempt {}, data cache {}]\n{}",
1630            self.session_id,
1631            self.partition.job_id,
1632            self.partition.stage_id,
1633            self.stage_attempt_num,
1634            self.partition.partition_id,
1635            self.task_id,
1636            self.task_attempt,
1637            self.data_cache,
1638            plan
1639        )
1640    }
1641}
1642
1643impl TaskDescription {
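    /// Number of shuffle output partitions this task will produce.
    ///
    /// Note: this assumes the task's plan is rooted at a `ShuffleWriterExec` and will panic
    /// otherwise; when there is no hash output partitioning a single partition is reported.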
1644    pub fn get_output_partition_number(&self) -> usize {
1645        let shuffle_writer = self
1646            .plan
1647            .as_any()
1648            .downcast_ref::<ShuffleWriterExec>()
1649            .unwrap();
1650        shuffle_writer
1651            .shuffle_output_partitioning()
1652            .map(|partitioning| partitioning.partition_count())
1653            .unwrap_or(1)
1654    }
1655}
1656
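/// Convert the shuffle write partitions reported by an executor into `PartitionLocation`s
/// describing where (which executor and path) each shuffle partition can be fetched from.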
1657fn partition_to_location(
1658    job_id: &str,
1659    map_partition_id: usize,
1660    stage_id: usize,
1661    executor: &ExecutorMetadata,
1662    shuffles: Vec<ShuffleWritePartition>,
1663) -> Vec<PartitionLocation> {
1664    shuffles
1665        .into_iter()
1666        .map(|shuffle| PartitionLocation {
1667            map_partition_id,
1668            partition_id: PartitionId {
1669                job_id: job_id.to_owned(),
1670                stage_id,
1671                partition_id: shuffle.partition_id as usize,
1672            },
1673            executor_meta: executor.clone(),
1674            partition_stats: PartitionStats::new(
1675                Some(shuffle.num_rows),
1676                Some(shuffle.num_batches),
1677                Some(shuffle.num_bytes),
1678            ),
1679            path: shuffle.path,
1680        })
1681        .collect()
1682}
1683
1684#[cfg(test)]
1685mod test {
1686    use std::collections::HashSet;
1687
1688    use crate::scheduler_server::event::QueryStageSchedulerEvent;
1689    use kapot_core::error::Result;
1690    use kapot_core::serde::protobuf::{
1691        self, failed_task, job_status, ExecutionError, FailedTask, FetchPartitionError,
1692        IoError, JobStatus, TaskKilled,
1693    };
1694
1695    use crate::state::execution_graph::ExecutionGraph;
1696    use crate::test_utils::{
1697        mock_completed_task, mock_executor, mock_failed_task,
1698        revive_graph_and_complete_next_stage,
1699        revive_graph_and_complete_next_stage_with_executor, test_aggregation_plan,
1700        test_coalesce_plan, test_join_plan, test_two_aggregations_plan,
1701        test_union_all_plan, test_union_plan,
1702    };
1703
1704    #[tokio::test]
1705    async fn test_drain_tasks() -> Result<()> {
1706        let mut agg_graph = test_aggregation_plan(4).await;
1707
1708        println!("Graph: {agg_graph:?}");
1709
1710        drain_tasks(&mut agg_graph)?;
1711
1712        assert!(
1713            agg_graph.is_successful(),
1714            "Failed to complete aggregation plan"
1715        );
1716
1717        let mut coalesce_graph = test_coalesce_plan(4).await;
1718
1719        drain_tasks(&mut coalesce_graph)?;
1720
1721        assert!(
1722            coalesce_graph.is_successful(),
1723            "Failed to complete coalesce plan"
1724        );
1725
1726        let mut join_graph = test_join_plan(4).await;
1727
1728        drain_tasks(&mut join_graph)?;
1729
1730        println!("{join_graph:?}");
1731
1732        assert!(join_graph.is_successful(), "Failed to complete join plan");
1733
1734        let mut union_all_graph = test_union_all_plan(4).await;
1735
1736        drain_tasks(&mut union_all_graph)?;
1737
1738        println!("{union_all_graph:?}");
1739
1740        assert!(
1741            union_all_graph.is_successful(),
1742            "Failed to complete union plan"
1743        );
1744
1745        let mut union_graph = test_union_plan(4).await;
1746
1747        drain_tasks(&mut union_graph)?;
1748
1749        println!("{union_graph:?}");
1750
1751        assert!(union_graph.is_successful(), "Failed to complete union plan");
1752
1753        Ok(())
1754    }
1755
1756    #[tokio::test]
1757    async fn test_finalize() -> Result<()> {
1758        let mut agg_graph = test_aggregation_plan(4).await;
1759
1760        drain_tasks(&mut agg_graph)?;
1761
1762        let status = agg_graph.status();
1763
1764        assert!(matches!(
1765            status,
1766            protobuf::JobStatus {
1767                status: Some(job_status::Status::Successful(_)),
1768                ..
1769            }
1770        ));
1771
1772        let outputs = agg_graph.output_locations();
1773
1774        assert_eq!(outputs.len(), agg_graph.output_partitions);
1775
1776        for location in outputs {
1777            assert_eq!(location.executor_meta.host, "localhost2".to_owned());
1778        }
1779
1780        Ok(())
1781    }
1782
1783    #[tokio::test]
1784    async fn test_reset_completed_stage_executor_lost() -> Result<()> {
1785        let executor1 = mock_executor("executor-id1".to_string());
1786        let executor2 = mock_executor("executor-id2".to_string());
1787        let mut join_graph = test_join_plan(4).await;
1788
1789        // With the improvement of https://github.com/apache/arrow-datafusion/pull/4122,
1790        // unnecessary RepartitionExec can be removed
1791        assert_eq!(join_graph.stage_count(), 4);
1792        assert_eq!(join_graph.available_tasks(), 0);
1793
1794        // Call revive to move the two leaf Resolved stages to Running
1795        join_graph.revive();
1796
1797        assert_eq!(join_graph.stage_count(), 4);
1798        assert_eq!(join_graph.available_tasks(), 4);
1799
1800        // Complete the first stage
1801        revive_graph_and_complete_next_stage_with_executor(&mut join_graph, &executor1)?;
1802
1803        // Complete the second stage
1804        revive_graph_and_complete_next_stage_with_executor(&mut join_graph, &executor2)?;
1805
1806        join_graph.revive();
1807        // There are 4 tasks pending schedule for the 3rd stage
1808        assert_eq!(join_graph.available_tasks(), 4);
1809
1810        // Complete 1 task
1811        if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
1812            let task_status = mock_completed_task(task, &executor1.id);
1813            join_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
1814        }
1815        // Mock 1 running task
1816        let _task = join_graph.pop_next_task(&executor1.id)?;
1817
1818        let reset = join_graph.reset_stages_on_lost_executor(&executor1.id)?;
1819
1820        // Two stages were reset: 1 Running stage rolled back to Unresolved and 1 Completed stage moved back to Running
1821        assert_eq!(reset.0.len(), 2);
1822        assert_eq!(join_graph.available_tasks(), 2);
1823
1824        drain_tasks(&mut join_graph)?;
1825        assert!(join_graph.is_successful(), "Failed to complete join plan");
1826
1827        Ok(())
1828    }
1829
1830    #[tokio::test]
1831    async fn test_reset_resolved_stage_executor_lost() -> Result<()> {
1832        let executor1 = mock_executor("executor-id1".to_string());
1833        let executor2 = mock_executor("executor-id2".to_string());
1834        let mut join_graph = test_join_plan(4).await;
1835
1836        assert_eq!(join_graph.stage_count(), 4);
1837        assert_eq!(join_graph.available_tasks(), 0);
1838
1839        // Call revive to move the two leaf Resolved stages to Running
1840        join_graph.revive();
1841
1842        assert_eq!(join_graph.stage_count(), 4);
1843        assert_eq!(join_graph.available_tasks(), 4);
1844
1845        // Complete the first stage
1846        assert_eq!(revive_graph_and_complete_next_stage(&mut join_graph)?, 2);
1847
1848        // Complete the second stage
1849        assert_eq!(
1850            revive_graph_and_complete_next_stage_with_executor(
1851                &mut join_graph,
1852                &executor2
1853            )?,
1854            2
1855        );
1856
1857        // There are 0 tasks pending schedule now
1858        assert_eq!(join_graph.available_tasks(), 0);
1859
1860        let reset = join_graph.reset_stages_on_lost_executor(&executor1.id)?;
1861
1862        // Two stages were reset: 1 Resolved stage rolled back to Unresolved and 1 Completed stage moved back to Running
1863        assert_eq!(reset.0.len(), 2);
1864        assert_eq!(join_graph.available_tasks(), 2);
1865
1866        drain_tasks(&mut join_graph)?;
1867        assert!(join_graph.is_successful(), "Failed to complete join plan");
1868
1869        Ok(())
1870    }
1871
1872    #[tokio::test]
1873    async fn test_task_update_after_reset_stage() -> Result<()> {
1874        let executor1 = mock_executor("executor-id1".to_string());
1875        let executor2 = mock_executor("executor-id2".to_string());
1876        let mut agg_graph = test_aggregation_plan(4).await;
1877
1878        assert_eq!(agg_graph.stage_count(), 2);
1879        assert_eq!(agg_graph.available_tasks(), 0);
1880
1881        // Call revive to move the leaf Resolved stages to Running
1882        agg_graph.revive();
1883
1884        assert_eq!(agg_graph.stage_count(), 2);
1885        assert_eq!(agg_graph.available_tasks(), 2);
1886
1887        // Complete the first stage
1888        revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, &executor1)?;
1889
1890        // 1st task in the second stage
1891        if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
1892            let task_status = mock_completed_task(task, &executor2.id);
1893            agg_graph.update_task_status(&executor2, vec![task_status], 1, 1)?;
1894        }
1895
1896        // 2nd task in the second stage
1897        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
1898            let task_status = mock_completed_task(task, &executor1.id);
1899            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
1900        }
1901
1902        // 3rd task in the second stage, scheduled but not completed
1903        let task = agg_graph.pop_next_task(&executor1.id)?;
1904
1905        // There is 1 task pending schedule now
1906        assert_eq!(agg_graph.available_tasks(), 1);
1907
1908        let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
1909
1910        // 3rd task status update comes later.
1911        let task_status = mock_completed_task(task.unwrap(), &executor1.id);
1912        agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
1913
1914        // Two stages were reset: 1 Running stage rolled back to Unresolved and 1 Completed stage moved back to Running
1915        assert_eq!(reset.0.len(), 2);
1916        assert_eq!(agg_graph.available_tasks(), 2);
1917
1918        // Call the reset again
1919        let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
1920        assert_eq!(reset.0.len(), 0);
1921        assert_eq!(agg_graph.available_tasks(), 2);
1922
1923        drain_tasks(&mut agg_graph)?;
1924        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
1925
1926        Ok(())
1927    }
1928
1929    #[tokio::test]
1930    async fn test_do_not_retry_killed_task() -> Result<()> {
1931        let executor = mock_executor("executor-id-123".to_string());
1932        let mut agg_graph = test_aggregation_plan(4).await;
1933        // Call revive to move the leaf Resolved stages to Running
1934        agg_graph.revive();
1935
1936        // Complete the first stage
1937        revive_graph_and_complete_next_stage(&mut agg_graph)?;
1938
1939        // 1st task in the second stage
1940        let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap();
1941        let task_status1 = mock_completed_task(task1, &executor.id);
1942
1943        // 2nd task in the second stage
1944        let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap();
1945        let task_status2 = mock_failed_task(
1946            task2,
1947            FailedTask {
1948                error: "Killed".to_string(),
1949                retryable: false,
1950                count_to_failures: false,
1951                failed_reason: Some(failed_task::FailedReason::TaskKilled(TaskKilled {})),
1952            },
1953        );
1954
1955        agg_graph.update_task_status(
1956            &executor,
1957            vec![task_status1, task_status2],
1958            4,
1959            4,
1960        )?;
1961
1962        assert_eq!(agg_graph.available_tasks(), 2);
1963        drain_tasks(&mut agg_graph)?;
1964        assert_eq!(agg_graph.available_tasks(), 0);
1965
1966        assert!(
1967            !agg_graph.is_successful(),
1968            "Expected the agg graph can not complete"
1969        );
1970        Ok(())
1971    }
1972
1973    #[tokio::test]
1974    async fn test_max_task_failed_count() -> Result<()> {
1975        let executor = mock_executor("executor-id2".to_string());
1976        let mut agg_graph = test_aggregation_plan(2).await;
1977        // Call revive to move the leaf Resolved stages to Running
1978        agg_graph.revive();
1979
1980        // Complete the first stage
1981        revive_graph_and_complete_next_stage(&mut agg_graph)?;
1982
1983        // 1st task in the second stage
1984        let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap();
1985        let task_status1 = mock_completed_task(task1, &executor.id);
1986
1987        // 2nd task in the second stage, failed due to IOError
1988        let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap();
1989        let task_status2 = mock_failed_task(
1990            task2.clone(),
1991            FailedTask {
1992                error: "IOError".to_string(),
1993                retryable: true,
1994                count_to_failures: true,
1995                failed_reason: Some(failed_task::FailedReason::IoError(IoError {})),
1996            },
1997        );
1998
1999        agg_graph.update_task_status(
2000            &executor,
2001            vec![task_status1, task_status2],
2002            4,
2003            4,
2004        )?;
2005
2006        assert_eq!(agg_graph.available_tasks(), 1);
2007
2008        let mut last_attempt = 0;
2009        // 2nd task's attempts
2010        for attempt in 1..5 {
2011            if let Some(task2_attempt) = agg_graph.pop_next_task(&executor.id)? {
2012                assert_eq!(
2013                    task2_attempt.partition.partition_id,
2014                    task2.partition.partition_id
2015                );
2016                assert_eq!(task2_attempt.task_attempt, attempt);
2017                last_attempt = task2_attempt.task_attempt;
2018                let task_status = mock_failed_task(
2019                    task2_attempt.clone(),
2020                    FailedTask {
2021                        error: "IOError".to_string(),
2022                        retryable: true,
2023                        count_to_failures: true,
2024                        failed_reason: Some(failed_task::FailedReason::IoError(
2025                            IoError {},
2026                        )),
2027                    },
2028                );
2029                agg_graph.update_task_status(&executor, vec![task_status], 4, 4)?;
2030            }
2031        }
2032
2033        assert!(
2034            matches!(
2035                agg_graph.status,
2036                JobStatus {
2037                    status: Some(job_status::Status::Failed(_)),
2038                    ..
2039                }
2040            ),
2041            "Expected job status to be Failed"
2042        );
2043
2044        assert_eq!(last_attempt, 3);
2045
2046        let failure_reason = format!("{:?}", agg_graph.status);
2047        assert!(failure_reason.contains("Task 1 in Stage 2 failed 4 times, fail the stage, most recent failure reason"));
2048        assert!(failure_reason.contains("IOError"));
2049        assert!(!agg_graph.is_successful());
2050
2051        Ok(())
2052    }
2053
2054    #[tokio::test]
2055    async fn test_long_delayed_failed_task_after_executor_lost() -> Result<()> {
2056        let executor1 = mock_executor("executor-id1".to_string());
2057        let executor2 = mock_executor("executor-id2".to_string());
2058        let mut agg_graph = test_aggregation_plan(4).await;
2059        // Call revive to move the leaf Resolved stages to Running
2060        agg_graph.revive();
2061
2062        // Complete the Stage 1
2063        revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, &executor1)?;
2064
2065        // 1st task in the Stage 2
2066        if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
2067            let task_status = mock_completed_task(task, &executor2.id);
2068            agg_graph.update_task_status(&executor2, vec![task_status], 1, 1)?;
2069        }
2070
2071        // 2nd task in the Stage 2
2072        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2073            let task_status = mock_completed_task(task, &executor1.id);
2074            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
2075        }
2076
2077        // 3rd task in the Stage 2, scheduled on executor 2 but not completed
2078        let task = agg_graph.pop_next_task(&executor2.id)?;
2079
2080        // There is 1 task pending schedule now
2081        assert_eq!(agg_graph.available_tasks(), 1);
2082
2083        // executor 1 lost
2084        let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
2085
2086        // Two stages were reset: Stage 2 rolled back to Unresolved and Stage 1 moved back to Running
2087        assert_eq!(reset.0.len(), 2);
2088        assert_eq!(agg_graph.available_tasks(), 2);
2089
2090        // Complete the Stage 1 again
2091        revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, &executor1)?;
2092
2093        // Stage 2 move to Running
2094        agg_graph.revive();
2095        assert_eq!(agg_graph.available_tasks(), 4);
2096
2097        // 3rd task in Stage 2 update comes very late due to runtime execution error.
2098        let task_status = mock_failed_task(
2099            task.unwrap(),
2100            FailedTask {
2101                error: "ExecutionError".to_string(),
2102                retryable: false,
2103                count_to_failures: false,
2104                failed_reason: Some(failed_task::FailedReason::ExecutionError(
2105                    ExecutionError {},
2106                )),
2107            },
2108        );
2109
2110        // This long-delayed failed task should not fail the stage/job and should not trigger any query stage events
2111        let query_stage_events =
2112            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2113        assert!(query_stage_events.is_empty());
2114
2115        drain_tasks(&mut agg_graph)?;
2116        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
2117
2118        Ok(())
2119    }
2120
2121    #[tokio::test]
2122    async fn test_normal_fetch_failure() -> Result<()> {
2123        let executor1 = mock_executor("executor-id1".to_string());
2124        let executor2 = mock_executor("executor-id2".to_string());
2125        let mut agg_graph = test_aggregation_plan(4).await;
2126        // Call revive to move the leaf Resolved stages to Running
2127        agg_graph.revive();
2128
2129        // Complete the Stage 1
2130        revive_graph_and_complete_next_stage(&mut agg_graph)?;
2131
2132        // 1st task in the Stage 2
2133        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
2134        let task_status1 = mock_completed_task(task1, &executor2.id);
2135
2136        // 2nd task in the Stage 2, failed due to FetchPartitionError
2137        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
2138        let task_status2 = mock_failed_task(
2139            task2,
2140            FailedTask {
2141                error: "FetchPartitionError".to_string(),
2142                retryable: false,
2143                count_to_failures: false,
2144                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2145                    FetchPartitionError {
2146                        executor_id: executor1.id.clone(),
2147                        map_stage_id: 1,
2148                        map_partition_id: 0,
2149                    },
2150                )),
2151            },
2152        );
2153
2154        let mut running_task_count = 0;
2155        while let Some(_task) = agg_graph.pop_next_task(&executor2.id)? {
2156            running_task_count += 1;
2157        }
2158        assert_eq!(running_task_count, 2);
2159
2160        let stage_events = agg_graph.update_task_status(
2161            &executor2,
2162            vec![task_status1, task_status2],
2163            4,
2164            4,
2165        )?;
2166
2167        assert_eq!(stage_events.len(), 1);
2168        assert!(matches!(
2169            stage_events[0],
2170            QueryStageSchedulerEvent::CancelTasks(_)
2171        ));
2172
2173        // Stage 1 is running
2174        let running_stage = agg_graph.running_stages();
2175        assert_eq!(running_stage.len(), 1);
2176        assert_eq!(running_stage[0], 1);
2177        assert_eq!(agg_graph.available_tasks(), 2);
2178
2179        drain_tasks(&mut agg_graph)?;
2180        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
2181        Ok(())
2182    }
2183
2184    #[tokio::test]
2185    async fn test_many_fetch_failures_in_one_stage() -> Result<()> {
2186        let executor1 = mock_executor("executor-id1".to_string());
2187        let executor2 = mock_executor("executor-id2".to_string());
2188        let executor3 = mock_executor("executor-id3".to_string());
2189        let mut agg_graph = test_two_aggregations_plan(8).await;
2190
2191        agg_graph.revive();
2192        assert_eq!(agg_graph.stage_count(), 3);
2193
2194        // Complete the Stage 1
2195        revive_graph_and_complete_next_stage(&mut agg_graph)?;
2196
2197        // Complete the Stage 2, 5 tasks run on executor_2 and 3 tasks run on executor_1
2198        for _i in 0..5 {
2199            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
2200                let task_status = mock_completed_task(task, &executor2.id);
2201                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
2202            }
2203        }
2204        assert_eq!(agg_graph.available_tasks(), 3);
2205        for _i in 0..3 {
2206            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2207                let task_status = mock_completed_task(task, &executor1.id);
2208                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2209            }
2210        }
2211
2212        // Run Stage 3, 6 tasks failed due to FetchPartitionError on different map partitions on executor_2
2213        let mut many_fetch_failure_status = vec![];
2214        for part in 2..8 {
2215            if let Some(task) = agg_graph.pop_next_task(&executor3.id)? {
2216                let task_status = mock_failed_task(
2217                    task,
2218                    FailedTask {
2219                        error: "FetchPartitionError".to_string(),
2220                        retryable: false,
2221                        count_to_failures: false,
2222                        failed_reason: Some(
2223                            failed_task::FailedReason::FetchPartitionError(
2224                                FetchPartitionError {
2225                                    executor_id: executor2.id.clone(),
2226                                    map_stage_id: 2,
2227                                    map_partition_id: part,
2228                                },
2229                            ),
2230                        ),
2231                    },
2232                );
2233                many_fetch_failure_status.push(task_status);
2234            }
2235        }
2236        assert_eq!(many_fetch_failure_status.len(), 6);
2237        agg_graph.update_task_status(&executor3, many_fetch_failure_status, 4, 4)?;
2238
2239        // The Running stage should be Stage 2 now
2240        let running_stage = agg_graph.running_stages();
2241        assert_eq!(running_stage.len(), 1);
2242        assert_eq!(running_stage[0], 2);
2243        assert_eq!(agg_graph.available_tasks(), 5);
2244
2245        drain_tasks(&mut agg_graph)?;
2246        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
2247        Ok(())
2248    }
2249
2250    #[tokio::test]
2251    async fn test_many_consecutive_stage_fetch_failures() -> Result<()> {
2252        let executor1 = mock_executor("executor-id1".to_string());
2253        let executor2 = mock_executor("executor-id2".to_string());
2254        let mut agg_graph = test_aggregation_plan(4).await;
2255        // Call revive to move the leaf Resolved stages to Running
2256        agg_graph.revive();
2257
2258        for attempt in 0..6 {
2259            revive_graph_and_complete_next_stage(&mut agg_graph)?;
2260
2261            // 1st task in the Stage 2, failed due to FetchPartitionError
2262            if let Some(task1) = agg_graph.pop_next_task(&executor2.id)? {
2263                let task_status1 = mock_failed_task(
2264                    task1.clone(),
2265                    FailedTask {
2266                        error: "FetchPartitionError".to_string(),
2267                        retryable: false,
2268                        count_to_failures: false,
2269                        failed_reason: Some(
2270                            failed_task::FailedReason::FetchPartitionError(
2271                                FetchPartitionError {
2272                                    executor_id: executor1.id.clone(),
2273                                    map_stage_id: 1,
2274                                    map_partition_id: 0,
2275                                },
2276                            ),
2277                        ),
2278                    },
2279                );
2280
2281                let stage_events =
2282                    agg_graph.update_task_status(&executor2, vec![task_status1], 4, 4)?;
2283
2284                if attempt < 3 {
2285                    // No JobRunningFailed stage events
2286                    assert_eq!(stage_events.len(), 0);
2287                    // Stage 1 is running
2288                    let running_stage = agg_graph.running_stages();
2289                    assert_eq!(running_stage.len(), 1);
2290                    assert_eq!(running_stage[0], 1);
2291                    assert_eq!(agg_graph.available_tasks(), 2);
2292                } else {
2293                    // Job fails after exceeding the max_stage_failures
2294                    assert_eq!(stage_events.len(), 1);
2295                    assert!(matches!(
2296                        stage_events[0],
2297                        QueryStageSchedulerEvent::JobRunningFailed { .. }
2298                    ));
2299                    // Stage 2 is still running
2300                    let running_stage = agg_graph.running_stages();
2301                    assert_eq!(running_stage.len(), 1);
2302                    assert_eq!(running_stage[0], 2);
2303                }
2304            }
2305        }
2306
2307        drain_tasks(&mut agg_graph)?;
2308        assert!(!agg_graph.is_successful(), "Expect to fail the agg plan");
2309
2310        let failure_reason = format!("{:?}", agg_graph.status);
2311        assert!(failure_reason.contains("Job failed due to stage 2 failed: Stage 2 has failed 4 times, most recent failure reason"));
2312        assert!(failure_reason.contains("FetchPartitionError"));
2313
2314        Ok(())
2315    }
2316
2317    #[tokio::test]
2318    async fn test_long_delayed_fetch_failures() -> Result<()> {
2319        let executor1 = mock_executor("executor-id1".to_string());
2320        let executor2 = mock_executor("executor-id2".to_string());
2321        let executor3 = mock_executor("executor-id3".to_string());
2322        let mut agg_graph = test_two_aggregations_plan(8).await;
2323
2324        agg_graph.revive();
2325        assert_eq!(agg_graph.stage_count(), 3);
2326
2327        // Complete the Stage 1
2328        revive_graph_and_complete_next_stage(&mut agg_graph)?;
2329
2330        // Complete the Stage 2, 5 tasks run on executor_2, 2 tasks run on executor_1, 1 task runs on executor_3
2331        for _i in 0..5 {
2332            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
2333                let task_status = mock_completed_task(task, &executor2.id);
2334                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
2335            }
2336        }
2337        assert_eq!(agg_graph.available_tasks(), 3);
2338
2339        for _i in 0..2 {
2340            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2341                let task_status = mock_completed_task(task, &executor1.id);
2342                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2343            }
2344        }
2345
2346        if let Some(task) = agg_graph.pop_next_task(&executor3.id)? {
2347            let task_status = mock_completed_task(task, &executor3.id);
2348            agg_graph.update_task_status(&executor3, vec![task_status], 4, 4)?;
2349        }
2350        assert_eq!(agg_graph.available_tasks(), 0);
2351
2352        // Run Stage 3
2353        // 1st task scheduled
2354        let task_1 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2355        // 2nd task scheduled
2356        let task_2 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2357        // 3rd task scheduled
2358        let task_3 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2359        // 4th task scheduled
2360        let task_4 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2361        // 5th task scheduled
2362        let task_5 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2363
2364        // Stage 3, 1st task failed due to FetchPartitionError(executor2)
2365        let task_status_1 = mock_failed_task(
2366            task_1,
2367            FailedTask {
2368                error: "FetchPartitionError".to_string(),
2369                retryable: false,
2370                count_to_failures: false,
2371                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2372                    FetchPartitionError {
2373                        executor_id: executor2.id.clone(),
2374                        map_stage_id: 2,
2375                        map_partition_id: 0,
2376                    },
2377                )),
2378            },
2379        );
2380        agg_graph.update_task_status(&executor3, vec![task_status_1], 4, 4)?;
2381
2382        // The Running stage is Stage 2 now
2383        let running_stage = agg_graph.running_stages();
2384        assert_eq!(running_stage.len(), 1);
2385        assert_eq!(running_stage[0], 2);
2386        assert_eq!(agg_graph.available_tasks(), 5);
2387
2388        // Stage 3, 2nd task failed due to FetchPartitionError(executor2)
2389        let task_status_2 = mock_failed_task(
2390            task_2,
2391            FailedTask {
2392                error: "FetchPartitionError".to_string(),
2393                retryable: false,
2394                count_to_failures: false,
2395                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2396                    FetchPartitionError {
2397                        executor_id: executor2.id.clone(),
2398                        map_stage_id: 2,
2399                        map_partition_id: 1,
2400                    },
2401                )),
2402            },
2403        );
2404        // This task update should be ignored
2405        agg_graph.update_task_status(&executor3, vec![task_status_2], 4, 4)?;
2406        let running_stage = agg_graph.running_stages();
2407        assert_eq!(running_stage.len(), 1);
2408        assert_eq!(running_stage[0], 2);
2409        assert_eq!(agg_graph.available_tasks(), 5);
2410
2411        // Stage 3, 3rd task failed due to FetchPartitionError(executor1)
2412        let task_status_3 = mock_failed_task(
2413            task_3,
2414            FailedTask {
2415                error: "FetchPartitionError".to_string(),
2416                retryable: false,
2417                count_to_failures: false,
2418                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2419                    FetchPartitionError {
2420                        executor_id: executor1.id.clone(),
2421                        map_stage_id: 2,
2422                        map_partition_id: 1,
2423                    },
2424                )),
2425            },
2426        );
2427        // This task update should be handled because it has a different failure reason
2428        agg_graph.update_task_status(&executor3, vec![task_status_3], 4, 4)?;
2429        // Running stage is still Stage 2, but available tasks changed to 7
2430        assert_eq!(running_stage.len(), 1);
2431        assert_eq!(running_stage[0], 2);
2432        assert_eq!(agg_graph.available_tasks(), 7);
2433
2434        // Finish 4 tasks in Stage 2, to make some progress
2435        for _i in 0..4 {
2436            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2437                let task_status = mock_completed_task(task, &executor1.id);
2438                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2439            }
2440        }
2441        assert_eq!(running_stage.len(), 1);
2442        assert_eq!(running_stage[0], 2);
2443        assert_eq!(agg_graph.available_tasks(), 3);
2444
2445        // Stage 3, 4th task failed due to FetchPartitionError(executor1)
2446        let task_status_4 = mock_failed_task(
2447            task_4,
2448            FailedTask {
2449                error: "FetchPartitionError".to_string(),
2450                retryable: false,
2451                count_to_failures: false,
2452                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2453                    FetchPartitionError {
2454                        executor_id: executor1.id.clone(),
2455                        map_stage_id: 2,
2456                        map_partition_id: 1,
2457                    },
2458                )),
2459            },
2460        );
2461        // This task update should be ignored because the same failure reason is already handled
2462        agg_graph.update_task_status(&executor3, vec![task_status_4], 4, 4)?;
2463        let running_stage = agg_graph.running_stages();
2464        assert_eq!(running_stage.len(), 1);
2465        assert_eq!(running_stage[0], 2);
2466        assert_eq!(agg_graph.available_tasks(), 3);
2467
2468        // Finish the other 3 tasks in Stage 2
2469        for _i in 0..3 {
2470            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2471                let task_status = mock_completed_task(task, &executor1.id);
2472                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2473            }
2474        }
2475        assert_eq!(agg_graph.available_tasks(), 0);
2476
2477        // Stage 3, the very long delayed 5th task failed due to FetchPartitionError(executor3)
2478        // Although the failure reason is new, this task update should be ignored
2479        // because its map stage's new attempt has finished and this stage's new attempt is already running
2480        let task_status_5 = mock_failed_task(
2481            task_5,
2482            FailedTask {
2483                error: "FetchPartitionError".to_string(),
2484                retryable: false,
2485                count_to_failures: false,
2486                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2487                    FetchPartitionError {
2488                        executor_id: executor3.id.clone(),
2489                        map_stage_id: 2,
2490                        map_partition_id: 1,
2491                    },
2492                )),
2493            },
2494        );
2495        agg_graph.update_task_status(&executor3, vec![task_status_5], 4, 4)?;
2496        // Stage 3's new attempt is running
2497        let running_stage = agg_graph.running_stages();
2498        assert_eq!(running_stage.len(), 1);
2499        assert_eq!(running_stage[0], 3);
2500        assert_eq!(agg_graph.available_tasks(), 8);
2501
2502        // There is one failed stage attempt: Stage 3. Stage 2 does not count toward failed attempts
2503        assert_eq!(agg_graph.failed_stage_attempts.len(), 1);
2504        assert_eq!(
2505            agg_graph.failed_stage_attempts.get(&3).cloned(),
2506            Some(HashSet::from([0]))
2507        );
2508        drain_tasks(&mut agg_graph)?;
2509        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
2510        // Failed stage attempts are cleaned
2511        assert_eq!(agg_graph.failed_stage_attempts.len(), 0);
2512
2513        Ok(())
2514    }
2515
2516    #[tokio::test]
2517    // This test case covers a race condition in delayed fetch failure handling:
2518    // TaskStatus updates of the input stage's new attempt arrive together with the parent stage's delayed FetchFailure
2519    async fn test_long_delayed_fetch_failures_race_condition() -> Result<()> {
2520        let executor1 = mock_executor("executor-id1".to_string());
2521        let executor2 = mock_executor("executor-id2".to_string());
2522        let executor3 = mock_executor("executor-id3".to_string());
2523        let mut agg_graph = test_two_aggregations_plan(8).await;
2524
2525        agg_graph.revive();
2526        assert_eq!(agg_graph.stage_count(), 3);
2527
2528        // Complete the Stage 1
2529        revive_graph_and_complete_next_stage(&mut agg_graph)?;
2530
2531        // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on executor_1
2532        for _i in 0..5 {
2533            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
2534                let task_status = mock_completed_task(task, &executor2.id);
2535                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
2536            }
2537        }
2538        assert_eq!(agg_graph.available_tasks(), 3);
2539
2540        for _i in 0..3 {
2541            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2542                let task_status = mock_completed_task(task, &executor1.id);
2543                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2544            }
2545        }
2546        assert_eq!(agg_graph.available_tasks(), 0);
2547
2548        // Run Stage 3
2549        // 1st task scheduled
2550        let task_1 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2551        // 2nd task scheduled
2552        let task_2 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
2553
2554        // Stage 3, 1st task failed due to FetchPartitionError(executor2)
2555        let task_status_1 = mock_failed_task(
2556            task_1,
2557            FailedTask {
2558                error: "FetchPartitionError".to_string(),
2559                retryable: false,
2560                count_to_failures: false,
2561                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2562                    FetchPartitionError {
2563                        executor_id: executor2.id.clone(),
2564                        map_stage_id: 2,
2565                        map_partition_id: 0,
2566                    },
2567                )),
2568            },
2569        );
2570        agg_graph.update_task_status(&executor3, vec![task_status_1], 4, 4)?;
2571
2572        // The Running stage is Stage 2 now
2573        let running_stage = agg_graph.running_stages();
2574        assert_eq!(running_stage.len(), 1);
2575        assert_eq!(running_stage[0], 2);
2576        assert_eq!(agg_graph.available_tasks(), 5);
2577
2578        // Complete the 5 tasks in Stage 2's new attempts
2579        let mut task_status_vec = vec![];
2580        for _i in 0..5 {
2581            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2582                task_status_vec.push(mock_completed_task(task, &executor1.id))
2583            }
2584        }
2585
2586        // Stage 3, 2nd task failed due to FetchPartitionError(executor1)
2587        let task_status_2 = mock_failed_task(
2588            task_2,
2589            FailedTask {
2590                error: "FetchPartitionError".to_string(),
2591                retryable: false,
2592                count_to_failures: false,
2593                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2594                    FetchPartitionError {
2595                        executor_id: executor1.id.clone(),
2596                        map_stage_id: 2,
2597                        map_partition_id: 1,
2598                    },
2599                )),
2600            },
2601        );
2602        task_status_vec.push(task_status_2);
2603
2604        // TaskStatus updates of Stage 2 arrive together with Stage 3's delayed FetchFailure update.
2605        // The successful tasks from Stage 2 will try to succeed Stage 2 while the delayed fetch failure tries to reset the TaskInfo
2606        agg_graph.update_task_status(&executor3, task_status_vec, 4, 4)?;
2607        // The running stage is still Stage 2; 3 new pending tasks were added due to FetchPartitionError(executor1)
2608        assert_eq!(agg_graph.running_stages().len(), 1);
2609        assert_eq!(agg_graph.running_stages()[0], 2);
2610        assert_eq!(agg_graph.available_tasks(), 3);
2611
2612        drain_tasks(&mut agg_graph)?;
2613        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
2614
2615        Ok(())
2616    }
2617
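    // Scenario: fetch failures reported against two different map stages in sequence.
    // A fetch failure in Stage 3 rolls the graph back to re-run Stage 2, and a later fetch
    // failure in Stage 2's re-run rolls it back further to Stage 1. Both failed attempts are
    // recorded in `failed_stage_attempts` and cleared once the job finally succeeds.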
2618    #[tokio::test]
2619    async fn test_fetch_failures_in_different_stages() -> Result<()> {
2620        let executor1 = mock_executor("executor-id1".to_string());
2621        let executor2 = mock_executor("executor-id2".to_string());
2622        let executor3 = mock_executor("executor-id3".to_string());
2623        let mut agg_graph = test_two_aggregations_plan(8).await;
2624
2625        agg_graph.revive();
2626        assert_eq!(agg_graph.stage_count(), 3);
2627
2628        // Complete Stage 1
2629        revive_graph_and_complete_next_stage(&mut agg_graph)?;
2630
2631        // Complete Stage 2: 5 tasks run on executor2, 3 tasks run on executor1
2632        for _i in 0..5 {
2633            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
2634                let task_status = mock_completed_task(task, &executor2.id);
2635                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
2636            }
2637        }
2638        assert_eq!(agg_graph.available_tasks(), 3);
2639        for _i in 0..3 {
2640            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
2641                let task_status = mock_completed_task(task, &executor1.id);
2642                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
2643            }
2644        }
2645        assert_eq!(agg_graph.available_tasks(), 0);
2646
2647        // Run Stage 3
2648        // 1st task in Stage 3, failed due to FetchPartitionError(executor1)
2649        if let Some(task1) = agg_graph.pop_next_task(&executor3.id)? {
2650            let task_status1 = mock_failed_task(
2651                task1,
2652                FailedTask {
2653                    error: "FetchPartitionError".to_string(),
2654                    retryable: false,
2655                    count_to_failures: false,
2656                    failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2657                        FetchPartitionError {
2658                            executor_id: executor1.id.clone(),
2659                            map_stage_id: 2,
2660                            map_partition_id: 0,
2661                        },
2662                    )),
2663                },
2664            );
2665
2666            let _stage_events =
2667                agg_graph.update_task_status(&executor3, vec![task_status1], 4, 4)?;
2668        }
2669        // The Running stage is Stage 2 now
2670        let running_stage = agg_graph.running_stages();
2671        assert_eq!(running_stage.len(), 1);
2672        assert_eq!(running_stage[0], 2);
2673        assert_eq!(agg_graph.available_tasks(), 3);
2674
2675        // 1st task in Stage 2's new attempt, failed due to FetchPartitionError(executor1)
2676        if let Some(task1) = agg_graph.pop_next_task(&executor3.id)? {
2677            let task_status1 = mock_failed_task(
2678                task1,
2679                FailedTask {
2680                    error: "FetchPartitionError".to_string(),
2681                    retryable: false,
2682                    count_to_failures: false,
2683                    failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2684                        FetchPartitionError {
2685                            executor_id: executor1.id.clone(),
2686                            map_stage_id: 1,
2687                            map_partition_id: 0,
2688                        },
2689                    )),
2690                },
2691            );
2692            let _stage_events =
2693                agg_graph.update_task_status(&executor3, vec![task_status1], 4, 4)?;
2694        }
2695        // The Running stage is Stage 1 now
2696        let running_stage = agg_graph.running_stages();
2697        assert_eq!(running_stage.len(), 1);
2698        assert_eq!(running_stage[0], 1);
2699        assert_eq!(agg_graph.available_tasks(), 2);
2700
2701        // There are two failed stage attempts: Stage 2 and Stage 3
2702        assert_eq!(agg_graph.failed_stage_attempts.len(), 2);
2703        assert_eq!(
2704            agg_graph.failed_stage_attempts.get(&2).cloned(),
2705            Some(HashSet::from([1]))
2706        );
2707        assert_eq!(
2708            agg_graph.failed_stage_attempts.get(&3).cloned(),
2709            Some(HashSet::from([0]))
2710        );
2711
2712        drain_tasks(&mut agg_graph)?;
2713        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
2714        assert_eq!(agg_graph.failed_stage_attempts.len(), 0);
2715        Ok(())
2716    }
2717
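    // Scenario: a successful task, a fetch failure, and a non-retryable ExecutionError are
    // reported in one status batch. The ExecutionError fails Stage 2, so the whole job is
    // expected to fail instead of being rescheduled.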
2718    #[tokio::test]
2719    async fn test_fetch_failure_with_normal_task_failure() -> Result<()> {
2720        let executor1 = mock_executor("executor-id1".to_string());
2721        let executor2 = mock_executor("executor-id2".to_string());
2722        let mut agg_graph = test_aggregation_plan(4).await;
2723
2724        // Complete Stage 1
2725        revive_graph_and_complete_next_stage(&mut agg_graph)?;
2726
2727        // 1st task in Stage 2, completed successfully
2728        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
2729        let task_status1 = mock_completed_task(task1, &executor2.id);
2730
2731        // 2nd task in Stage 2, failed due to FetchPartitionError
2732        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
2733        let task_status2 = mock_failed_task(
2734            task2,
2735            FailedTask {
2736                error: "FetchPartitionError".to_string(),
2737                retryable: false,
2738                count_to_failures: false,
2739                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
2740                    FetchPartitionError {
2741                        executor_id: executor1.id.clone(),
2742                        map_stage_id: 1,
2743                        map_partition_id: 0,
2744                    },
2745                )),
2746            },
2747        );
2748
2749        // 3rd task in Stage 2, failed due to ExecutionError
2750        let task3 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
2751        let task_status3 = mock_failed_task(
2752            task3,
2753            FailedTask {
2754                error: "ExecutionError".to_string(),
2755                retryable: false,
2756                count_to_failures: false,
2757                failed_reason: Some(failed_task::FailedReason::ExecutionError(
2758                    ExecutionError {},
2759                )),
2760            },
2761        );
2762
2763        let stage_events = agg_graph.update_task_status(
2764            &executor2,
2765            vec![task_status1, task_status2, task_status3],
2766            4,
2767            4,
2768        )?;
2769
2770        assert_eq!(stage_events.len(), 1);
2771        assert!(matches!(
2772            stage_events[0],
2773            QueryStageSchedulerEvent::JobRunningFailed { .. }
2774        ));
2775
2776        drain_tasks(&mut agg_graph)?;
2777        assert!(!agg_graph.is_successful(), "Expect to fail the agg plan");
2778
2779        let failure_reason = format!("{:?}", agg_graph.status);
2780        assert!(failure_reason.contains("Job failed due to stage 2 failed"));
2781        assert!(failure_reason.contains("ExecutionError"));
2782
2783        Ok(())
2784    }
2785
2786    // #[tokio::test]
2787    // async fn test_shuffle_files_should_cleaned_after_fetch_failure() -> Result<()> {
2788    //     todo!()
2789    // }
2790
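    /// Pops and "completes" every remaining task on a single mock executor until the graph
    /// has no runnable tasks left.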
2791    fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> {
2792        let executor = mock_executor("executor-id1".to_string());
2793        while let Some(task) = graph.pop_next_task(&executor.id)? {
2794            let task_status = mock_completed_task(task, &executor.id);
2795            graph.update_task_status(&executor, vec![task_status], 1, 1)?;
2796        }
2797
2798        Ok(())
2799    }
2800}