// datafusion_distributed/stage.rs

1use crate::execution_plans::{DistributedExec, NetworkCoalesceExec};
2use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
3use crate::{NetworkShuffleExec, PartitionIsolatorExec};
4use datafusion::common::plan_err;
5use datafusion::common::{HashMap, config_err};
6use datafusion::error::Result;
7use datafusion::execution::TaskContext;
8use datafusion::physical_plan::display::DisplayableExecutionPlan;
9use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet};
10use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable};
11use itertools::Either;
12use std::collections::VecDeque;
13use std::sync::Arc;
14use url::Url;
15use uuid::Uuid;
16
17/// A unit of isolation for a portion of a physical execution plan
18/// that can be executed independently and across a network boundary.
19/// It implements [`ExecutionPlan`] and can be executed to produce a
20/// stream of record batches.
21///
22/// If a stage has input stages, then those input stages will be executed on remote resources
23/// and will be provided the remainder of the stage tree.
24///
25/// For example, if our stage tree looks like this:
26///
27/// ```text
28///                       ┌─────────┐
29///                       │ stage 1 │
30///                       └───┬─────┘
31///                           │
32///                    ┌──────┴────────┐
33///               ┌────┴────┐     ┌────┴────┐
34///               │ stage 2 │     │ stage 3 │
35///               └────┬────┘     └─────────┘
36///                    │
37///             ┌──────┴────────┐
38///        ┌────┴────┐     ┌────┴────┐
39///        │ stage 4 │     │ Stage 5 │
40///        └─────────┘     └─────────┘
41///
42/// ```
43///
44/// Then executing Stage 1 will run its plan locally. Stage 1 has two inputs, Stage 2 and Stage 3. We
45/// know these will execute on remote resources. As such, the plan for Stage 1 must contain a
46/// [`NetworkShuffleExec`] node that will read the results of Stage 2 and Stage 3 and coalesce the
47/// results.
48///
49/// When Stage 1's [`NetworkShuffleExec`] node is executed, it makes an ArrowFlightRequest to the
50/// host assigned in the Stage. It provides the following Stage tree serialized in the body of the
51/// Arrow Flight Ticket:
52///
53/// ```text
54///               ┌─────────┐
55///               │ Stage 2 │
56///               └────┬────┘
57///                    │
58///             ┌──────┴────────┐
59///        ┌────┴────┐     ┌────┴────┐
60///        │ Stage 4 │     │ Stage 5 │
61///        └─────────┘     └─────────┘
62///
63/// ```
64///
65/// The receiving Worker will then execute Stage 2 and will repeat this process.
66///
67/// When Stage 4 is executed, it has no input tasks, so it is assumed that the plan included in that
68/// Stage can complete on its own; it's likely holding a leaf node in the overall physical plan and
69/// producing data from a [`DataSourceExec`].
70#[derive(Debug, Clone)]
71pub struct Stage {
72    /// Our query_id
73    pub(crate) query_id: Uuid,
74    /// Our stage number
75    pub(crate) num: usize,
76    /// The physical execution plan that this stage will execute. It will only be present if
77    /// accessing to it through the coordinating stage.
78    pub(crate) plan: Option<Arc<dyn ExecutionPlan>>,
79    /// Our tasks which tell us how finely grained to execute the partitions in
80    /// the plan
81    pub tasks: Vec<ExecutionTask>,
82}
83
84#[derive(Debug, Clone)]
85pub struct ExecutionTask {
86    /// The url of the worker that will execute this task.  A None value is interpreted as
87    /// unassigned.
88    pub(crate) url: Option<Url>,
89}
90
/// Identifies the task currently being executed and how many tasks run alongside it.
#[derive(Debug, Clone, PartialEq)]
pub struct DistributedTaskContext {
    /// Zero-based index of the current task.
    pub task_index: usize,
    /// Total number of tasks.
    pub task_count: usize,
}
96
97impl DistributedTaskContext {
98    pub fn from_ctx(ctx: &Arc<TaskContext>) -> Arc<Self> {
99        ctx.session_config()
100            .get_extension::<Self>()
101            .unwrap_or(Arc::new(DistributedTaskContext {
102                task_index: 0,
103                task_count: 1,
104            }))
105    }
106}
107
108impl Stage {
109    /// Creates a new `Stage` with the given plan and inputs. `ExecutionTasks` will be created for
110    /// each of the `n_tasks` specified tasks.
111    pub(crate) fn new(
112        query_id: Uuid,
113        num: usize,
114        plan: Arc<dyn ExecutionPlan>,
115        n_tasks: usize,
116    ) -> Self {
117        Self {
118            query_id,
119            num,
120            plan: Some(plan),
121            tasks: vec![ExecutionTask { url: None }; n_tasks],
122        }
123    }
124}
125
126use crate::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics};
127use crate::{NetworkBoundary, NetworkBoundaryExt};
128use datafusion::common::DataFusionError;
129use datafusion::physical_expr::Partitioning;
// Displaying a nice tree for stages.
//
// The challenge in doing this at the moment is that `TreeRenderVisitor` in
// `datafusion::physical_plan::display` is not public, and that it is also specific to an
// `ExecutionPlan` trait object, which a stage tree is not.
//
// TODO: try to upstream a change that renders trees (logical, physical, stages) against a
// generic trait rather than a specific trait object. This would allow us to use the same
// rendering code for all trees, including stages.
//
// In the meantime, we could make a dummy ExecutionPlan that would let us render the Stage
// tree; the helpers below render it by hand instead.
142use std::fmt::Write;
143
144/// explain_analyze renders an [ExecutionPlan] with metrics.
145pub fn explain_analyze(
146    executed: Arc<dyn ExecutionPlan>,
147    format: DistributedMetricsFormat,
148) -> Result<String, DataFusionError> {
149    match executed.as_any().downcast_ref::<DistributedExec>() {
150        None => Ok(DisplayableExecutionPlan::with_metrics(executed.as_ref())
151            .indent(true)
152            .to_string()),
153        Some(_) => {
154            let executed = rewrite_distributed_plan_with_metrics(executed.clone(), format)?;
155            Ok(display_plan_ascii(executed.as_ref(), true))
156        }
157    }
158}
159
// Box-drawing characters used to frame stages in the ASCII rendering.

/// Top-left corner of a stage box.
const LTCORNER: &str = "┌";
/// Bottom-left corner of a stage box.
const LDCORNER: &str = "└";
/// Left border of a stage box.
const VERTICAL: &str = "│";
/// Horizontal rule used in box headers and footers.
const HORIZONTAL: &str = "─";
165pub fn display_plan_ascii(plan: &dyn ExecutionPlan, show_metrics: bool) -> String {
166    if let Some(plan) = plan.as_any().downcast_ref::<DistributedExec>() {
167        let mut f = String::new();
168        display_ascii(Either::Left(plan), 0, show_metrics, &mut f).unwrap();
169        f
170    } else {
171        match show_metrics {
172            true => DisplayableExecutionPlan::with_metrics(plan)
173                .indent(true)
174                .to_string(),
175            false => displayable(plan).indent(true).to_string(),
176        }
177    }
178}
179
180fn display_ascii(
181    stage: Either<&DistributedExec, &Stage>,
182    depth: usize,
183    show_metrics: bool,
184    f: &mut String,
185) -> std::fmt::Result {
186    let plan = match stage {
187        Either::Left(distributed_exec) => distributed_exec.children().first().unwrap(),
188        Either::Right(stage) => {
189            let Some(plan) = &stage.plan else {
190                return write!(f, "StageExec: encoded input plan");
191            };
192            plan
193        }
194    };
195    match stage {
196        Either::Left(dist_exec) => {
197            writeln!(
198                f,
199                "{}{}{} DistributedExec {} {}{}",
200                "  ".repeat(depth),
201                LTCORNER,
202                HORIZONTAL.repeat(5),
203                HORIZONTAL.repeat(2),
204                format_tasks_for_stage(1, plan),
205                if show_metrics {
206                    format_metrics_by_task(&dist_exec.metrics().unwrap_or_default())
207                } else {
208                    "".into()
209                }
210            )?;
211        }
212        Either::Right(stage) => {
213            writeln!(
214                f,
215                "{}{}{} Stage {} {} {}",
216                "  ".repeat(depth),
217                LTCORNER,
218                HORIZONTAL.repeat(5),
219                stage.num,
220                HORIZONTAL.repeat(2),
221                format_tasks_for_stage(stage.tasks.len(), plan)
222            )?;
223        }
224    }
225
226    let mut plan_str = String::new();
227    display_inner_ascii(plan, 0, show_metrics, &mut plan_str)?;
228    let plan_str = plan_str
229        .split('\n')
230        .filter(|v| !v.is_empty())
231        .collect::<Vec<_>>()
232        .join(&format!("\n{}{}", "  ".repeat(depth), VERTICAL));
233    writeln!(f, "{}{}{}", "  ".repeat(depth), VERTICAL, plan_str)?;
234    writeln!(
235        f,
236        "{}{}{}",
237        "  ".repeat(depth),
238        LDCORNER,
239        HORIZONTAL.repeat(50)
240    )?;
241    for input_stage in find_input_stages(plan.as_ref()) {
242        display_ascii(Either::Right(input_stage), depth + 1, show_metrics, f)?;
243    }
244    Ok(())
245}
246
247fn display_inner_ascii(
248    plan: &Arc<dyn ExecutionPlan>,
249    indent: usize,
250    show_metrics: bool,
251    f: &mut String,
252) -> std::fmt::Result {
253    let metrics_str = if show_metrics {
254        if let Some(metrics) = plan.metrics() {
255            let formatted = format_metrics_by_task(&metrics);
256            if formatted.is_empty() {
257                ", metrics=[]".to_string()
258            } else {
259                format!(", metrics=[{formatted}]")
260            }
261        } else {
262            ", metrics=[]".to_string()
263        }
264    } else {
265        String::new()
266    };
267
268    let node_str = displayable(plan.as_ref()).one_line().to_string();
269    writeln!(
270        f,
271        "{} {}{metrics_str}",
272        " ".repeat(indent),
273        node_str.trim_end() // remove trailing newline
274    )?;
275
276    if plan.is_network_boundary() {
277        return Ok(());
278    }
279
280    for child in plan.children() {
281        display_inner_ascii(child, indent + 2, show_metrics, f)?;
282    }
283    Ok(())
284}
285
286/// Aggregates metrics by (name, task_id), preserving the [DISTRIBUTED_DATAFUSION_TASK_ID_LABEL]
287/// only. Metrics without a task_id label (ie. non distributed metrics) are aggregated together.
288///
289/// For a non-distributed plan, this is equivalent to [MetricsSet::aggregate_by_name] since there
290/// will be no task ids. For a distributed plan, it's expected that the metrics rewriter populated
291/// task id labels in all metrics.
292fn aggregate_by_task_id(metrics: &MetricsSet) -> MetricsSet {
293    // Key: (metric_name, Option<task_id>)
294    let mut map: HashMap<(String, Option<String>), Metric> = HashMap::new();
295
296    for metric in metrics.iter() {
297        let name = metric.value().name().to_string();
298        let task_id = metric
299            .labels()
300            .iter()
301            .find(|l| l.name() == DISTRIBUTED_DATAFUSION_TASK_ID_LABEL)
302            .map(|l| l.value().to_string());
303
304        let key = (name, task_id.clone());
305
306        map.entry(key)
307            .and_modify(|accum| {
308                accum.value_mut().aggregate(metric.value());
309            })
310            .or_insert_with(|| {
311                let labels = task_id
312                    .map(|id| vec![Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, id)])
313                    .unwrap_or_default();
314                let mut accum = Metric::new_with_labels(
315                    metric.value().new_empty(),
316                    None, // no partition
317                    labels,
318                );
319                accum.value_mut().aggregate(metric.value());
320                accum
321            });
322    }
323
324    let mut result = MetricsSet::new();
325    for (_, metric) in map {
326        result.push(Arc::new(metric));
327    }
328    result
329}
330
331/// Sorts metrics by display priority, then name, then by task_id (numerically).
332///
333/// For a non-distributed plan, this is equivalent to [MetricsSet::sorted_for_display] since there
334/// will be no task ids. For a distributed plan, it's expected that the metrics rewriter populated
335/// task id labels in all metrics.
336fn sorted_for_display_by_task_id(metrics: MetricsSet) -> MetricsSet {
337    let mut vec: Vec<Arc<Metric>> = metrics.iter().cloned().collect();
338    vec.sort_unstable_by_key(|metric| {
339        let task_id = metric
340            .labels()
341            .iter()
342            .find(|l| l.name() == DISTRIBUTED_DATAFUSION_TASK_ID_LABEL)
343            .and_then(|l| l.value().parse::<u64>().ok());
344        (
345            metric.value().display_sort_key(),
346            metric.value().name().to_owned(),
347            task_id,
348        )
349    });
350    let mut result = MetricsSet::new();
351    for m in vec {
352        result.push(m);
353    }
354    result
355}
356
357/// Formats metrics as "{metric_name}_{task_id}={value}, {metric_name}_{task_id}={value}"
358/// e.g., "output_rows_0=100, output_rows_1=150, elapsed_compute_0=50ns, elapsed_compute_1=100ns"
359///
360/// For a non-distributed plan, this is equivalent to using [ShowMetrics::Aggregated] /
361/// [DisplayableExecutionPlan::with_metrics] which aggregates, sorts, removes timestamps, and finally formats
362/// the metrics.
363///
364/// See
365/// https://github.com/apache/datafusion/blob/b463a9f9e3c9603eb2db7113125fea3a1b7f5455/datafusion/physical-plan/src/display.rs#L421.
366fn format_metrics_by_task(metrics: &MetricsSet) -> String {
367    let aggregated = aggregate_by_task_id(metrics);
368    let sorted = sorted_for_display_by_task_id(aggregated).timestamps_removed();
369
370    sorted
371        .iter()
372        .map(|m| {
373            let name = m.value().name();
374            let task_id = m
375                .labels()
376                .iter()
377                .find(|l| l.name() == DISTRIBUTED_DATAFUSION_TASK_ID_LABEL)
378                .map(|l| l.value());
379
380            match task_id {
381                Some(id) => format!("{name}_{id}={}", m.value()),
382                None => format!("{name}={}", m.value()),
383            }
384        })
385        .collect::<Vec<_>>()
386        .join(", ")
387}
388
389fn format_tasks_for_stage(n_tasks: usize, head: &Arc<dyn ExecutionPlan>) -> String {
390    let partitioning = head.properties().output_partitioning();
391    let input_partitions = partitioning.partition_count();
392    let hash_shuffle = matches!(partitioning, Partitioning::Hash(_, _));
393    let mut result = "Tasks: ".to_string();
394    let mut off = 0;
395    for i in 0..n_tasks {
396        result += &format!("t{i}:[");
397        let end = off + input_partitions - 1;
398        if input_partitions == 1 {
399            result += &format!("p{off}");
400        } else {
401            result += &format!("p{off}..p{end}");
402        }
403        result += "] ";
404        off += if hash_shuffle { 0 } else { input_partitions }
405    }
406    result
407}
408
/// Graphviz color scheme used for edges; see <https://graphviz.org/doc/info/colors.html>.
const COLOR_SCHEME: &str = "spectral6";
/// Number of colors in [`COLOR_SCHEME`]; the two must be kept in agreement.
const NUM_COLORS: usize = 6;
413
414/// This will render a regular or distributed datafusion plan as
415/// Graphviz dot format.
416/// You can view them on https://vis-js.com
417///
418/// Or it is often useful to experiment with plan output using
419/// https://datafusion-fiddle.vercel.app/
420pub fn display_plan_graphviz(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
421    let mut f = String::new();
422
423    writeln!(
424        f,
425        "digraph G {{
426  rankdir=BT
427  edge[colorscheme={COLOR_SCHEME}, penwidth=2.0]
428  splines=false
429"
430    )?;
431
432    if plan.as_any().is::<DistributedExec>() {
433        let mut max_num = 0;
434        let mut all_stages = find_all_stages(&plan)
435            .into_iter()
436            .inspect(|v| max_num = max_num.max(v.num))
437            .collect::<Vec<_>>();
438        let head_stage = Stage {
439            query_id: Default::default(),
440            num: max_num + 1,
441            plan: Some(plan.clone()),
442            tasks: vec![ExecutionTask { url: None }],
443        };
444        all_stages.insert(0, &head_stage);
445
446        // draw all tasks first
447        for stage in &all_stages {
448            for i in 0..stage.tasks.iter().len() {
449                let p = display_single_task(stage, i)?;
450                writeln!(f, "{p}")?;
451            }
452        }
453        // now draw edges between the tasks
454        for stage in &all_stages {
455            let Some(plan) = &stage.plan else { continue };
456            for input_stage in find_input_stages(plan.as_ref()) {
457                for task_i in 0..stage.tasks.len() {
458                    for input_task_i in 0..input_stage.tasks.len() {
459                        let edges =
460                            display_inter_task_edges(stage, task_i, input_stage, input_task_i)?;
461                        writeln!(
462                            f,
463                            "// edges from child stage {} task {} to stage {} task {}\n {}",
464                            input_stage.num, input_task_i, stage.num, task_i, edges
465                        )?;
466                    }
467                }
468            }
469        }
470    } else {
471        // single plan, not a stage tree
472        writeln!(f, "node[shape=none]")?;
473        let p = display_plan(&plan, 0, 1, 0)?;
474        writeln!(f, "{p}")?;
475    }
476
477    writeln!(f, "}}")?;
478
479    Ok(f)
480}
481
482fn display_single_task(stage: &Stage, task_i: usize) -> Result<String> {
483    let Some(plan) = &stage.plan else {
484        return config_err!("plan not present");
485    };
486    let partition_group =
487        build_partition_group(task_i, plan.output_partitioning().partition_count());
488
489    let mut f = String::new();
490    writeln!(
491        f,
492        "
493  subgraph \"cluster_stage_{}_task_{}_margin\" {{
494    style=invis
495    margin=20.0
496  subgraph \"cluster_stage_{}_task_{}\" {{
497    color=blue
498    style=dotted
499    label = \"Stage {} Task {} Partitions {}\"
500    labeljust=r
501    labelloc=b
502
503    node[shape=none]
504
505",
506        stage.num,
507        task_i,
508        stage.num,
509        task_i,
510        stage.num,
511        task_i,
512        format_pg(&partition_group)
513    )?;
514
515    writeln!(
516        f,
517        "{}",
518        display_plan(plan, task_i, stage.tasks.len(), stage.num)?
519    )?;
520    writeln!(f, "  }}")?;
521    writeln!(f, "  }}")?;
522
523    Ok(f)
524}
525
/// Renders the nodes of one task's `plan`, and the edges between them, as Graphviz DOT
/// statements.
///
/// Nodes are visited breadth-first and numbered from 1. The edge-drawing pass repeats the same
/// traversal, so it refers to the same node names. Neither pass descends below a network
/// boundary, whose inputs belong to another stage.
///
/// `task_i` and `n_tasks` are used to compute which partitions a [`PartitionIsolatorExec`]
/// assigns to this task; edges for other partitions are drawn dotted or invisible.
fn display_plan(
    plan: &Arc<dyn ExecutionPlan>,
    task_i: usize,
    n_tasks: usize,
    stage_num: usize,
) -> Result<String> {
    // draw all plans
    // we need to label the nodes including depth to uniquely identify them within this task
    // the tree node API provides depth first traversal, but we need breadth to align with
    // how we will draw edges below, so we'll do that.
    let mut queue = VecDeque::from([plan]);
    let mut node_index = 0;

    let mut f = String::new();
    while let Some(plan) = queue.pop_front() {
        node_index += 1;
        let p = display_single_plan(plan.as_ref(), stage_num, task_i, node_index)?;
        writeln!(f, "{p}")?;

        // The inputs of a network boundary belong to another stage and are drawn there.
        if plan.is_network_boundary() {
            continue;
        }
        for child in plan.children().iter() {
            queue.push_back(child);
        }
    }

    // draw edges between the plan nodes
    // Each queue entry is (node, its parent if any, the parent's BFS index).
    type PlanWithParent<'a> = (
        &'a Arc<dyn ExecutionPlan>,
        Option<&'a Arc<dyn ExecutionPlan>>,
        usize,
    );
    let mut queue: VecDeque<PlanWithParent> = VecDeque::from([(plan, None, 0usize)]);
    // Partitions assigned to this task by the most recently visited PartitionIsolatorExec.
    // NOTE(review): once set, this group is applied to every node visited afterwards in BFS
    // order, not only to the isolator's own subtree — confirm this is intended.
    let mut isolator_partition_group = None;
    node_index = 0;
    while let Some((plan, maybe_parent, parent_idx)) = queue.pop_front() {
        node_index += 1;
        if let Some(node) = plan.as_any().downcast_ref::<PartitionIsolatorExec>() {
            isolator_partition_group = Some(PartitionIsolatorExec::partition_group(
                node.input.output_partitioning().partition_count(),
                task_i,
                n_tasks,
            ));
        }
        if let Some(parent) = maybe_parent {
            let output_partitions = plan.output_partitioning().partition_count();

            // One edge per output partition, from this node's top port `t{i}` to the parent's
            // bottom port `b{i}`.
            for i in 0..output_partitions {
                let mut style = "";
                if plan.as_any().is::<PartitionIsolatorExec>() {
                    // Output partitions past the size of this task's group carry no data.
                    if i >= isolator_partition_group.as_ref().map_or(0, |v| v.len()) {
                        style = "[style=dotted, label=empty]";
                    }
                } else if let Some(partition_group) = &isolator_partition_group
                    && !partition_group.contains(&i)
                {
                    // Once an isolator has been seen, hide partitions outside this task's group.
                    style = "[style=invis]";
                }

                writeln!(
                    f,
                    "  {}_{}_{}_{}:t{}:n -> {}_{}_{}_{}:b{}:s {}[color={}]",
                    plan.name(),
                    stage_num,
                    task_i,
                    node_index,
                    i,
                    parent.name(),
                    stage_num,
                    task_i,
                    parent_idx,
                    i,
                    style,
                    i % NUM_COLORS + 1
                )?;
            }
        }

        if plan.as_ref().is_network_boundary() {
            continue;
        }

        for child in plan.children() {
            queue.push_back((child, Some(plan), node_index));
        }
    }
    Ok(f)
}
615
616/// We want to display a single plan as a three row table with the top and bottom being
617/// graphvis ports.
618///
619/// We accept an index to make the node name unique in the graphviz output within
620/// a plan at the same depth
621///
622/// An example of such a node would be:
623///
624/// ```text
625///       NetworkShuffleExec [label=<
626///     <TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">
627///         <TR>
628///             <TD CELLBORDER="0">
629///                 <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
630///                     <TR>
631///                         <TD PORT="t1"></TD>
632///                         <TD PORT="t2"></TD>
633///                     </TR>
634///                 </TABLE>
635///             </TD>
636///         </TR>
637///         <TR>
638///             <TD BORDER="0" CELLPADDING="0" CELLSPACING="0">
639///                 <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
640///                     <TR>
641///                         <TD>NetworkShuffleExec</TD>
642///                     </TR>
643///                 </TABLE>
644///             </TD>
645///         </TR>
646///         <TR>
647///             <TD CELLBORDER="0">
648///                 <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
649///                     <TR>
650///                         <TD PORT="b1"></TD>
651///                         <TD PORT="b2"></TD>
652///                     </TR>
653///                 </TABLE>
654///             </TD>
655///         </TR>
656///     </TABLE>
657/// >];
658/// ```
659pub fn display_single_plan(
660    plan: &(dyn ExecutionPlan + 'static),
661    stage_num: usize,
662    task_i: usize,
663    node_index: usize,
664) -> Result<String> {
665    let mut f = String::new();
666    let output_partitions = plan.output_partitioning().partition_count();
667    let input_partitions = if plan.is_network_boundary() {
668        output_partitions
669    } else if let Some(child) = plan.children().first() {
670        child.output_partitioning().partition_count()
671    } else {
672        1
673    };
674
675    writeln!(
676        f,
677        "
678    {}_{}_{}_{} [label=<
679    <TABLE BORDER='0' CELLBORDER='0' CELLSPACING='0' CELLPADDING='0'>
680        <TR>
681            <TD CELLBORDER='0'>
682                <TABLE BORDER='0' CELLBORDER='1' CELLSPACING='0'>
683                    <TR>",
684        plan.name(),
685        stage_num,
686        task_i,
687        node_index
688    )?;
689
690    for i in 0..output_partitions {
691        writeln!(f, "                        <TD PORT='t{i}'></TD>")?;
692    }
693
694    writeln!(
695        f,
696        "                   </TR>
697                </TABLE>
698            </TD>
699        </TR>
700        <TR>
701            <TD BORDER='0' CELLPADDING='0' CELLSPACING='0'>
702                <TABLE BORDER='0' CELLBORDER='1' CELLSPACING='0'>
703                    <TR>
704                        <TD>{}</TD>
705                    </TR>
706                </TABLE>
707            </TD>
708        </TR>
709        <TR>
710            <TD CELLBORDER='0'>
711                <TABLE BORDER='0' CELLBORDER='1' CELLSPACING='0'>
712                    <TR>",
713        plan.name()
714    )?;
715
716    for i in 0..input_partitions {
717        writeln!(f, "                        <TD PORT='b{i}'></TD>")?;
718    }
719
720    writeln!(
721        f,
722        "                   </TR>
723                </TABLE>
724            </TD>
725        </TR>
726    </TABLE>
727  >];
728"
729    )?;
730    Ok(f)
731}
732
733fn display_inter_task_edges(
734    stage: &Stage,
735    task_i: usize,
736    input_stage: &Stage,
737    input_task_i: usize,
738) -> Result<String> {
739    let Some(plan) = &stage.plan else {
740        return plan_err!("The inner plan of a stage was encoded.");
741    };
742    let Some(input_plan) = &input_stage.plan else {
743        return plan_err!("The inner plan of a stage was encoded.");
744    };
745    let mut f = String::new();
746
747    let mut queue = VecDeque::from([plan]);
748    let mut index = 0;
749    while let Some(plan) = queue.pop_front() {
750        index += 1;
751        if let Some(node) = plan.as_any().downcast_ref::<NetworkShuffleExec>() {
752            if node.input_stage().num != input_stage.num {
753                continue;
754            }
755            // draw the edges to this node pulling data up from its child
756            let output_partitions = plan.output_partitioning().partition_count();
757            for p in 0..output_partitions {
758                writeln!(
759                    f,
760                    "  {}_{}_{}_{}:t{}:n -> {}_{}_{}_{}:b{}:s [color={}]",
761                    input_plan.name(),
762                    input_stage.num,
763                    input_task_i,
764                    1, // the repartition exec is always the first node in the plan
765                    p + (task_i * output_partitions),
766                    plan.name(),
767                    stage.num,
768                    task_i,
769                    index,
770                    p,
771                    p % NUM_COLORS + 1
772                )?;
773            }
774            continue;
775        } else if let Some(node) = plan.as_any().downcast_ref::<NetworkCoalesceExec>() {
776            if node.input_stage().num != input_stage.num {
777                continue;
778            }
779            // draw the edges to this node pulling data up from its child
780            let output_partitions = plan.output_partitioning().partition_count();
781            let input_partitions_per_task = output_partitions / input_stage.tasks.len();
782            for p in 0..input_partitions_per_task {
783                writeln!(
784                    f,
785                    "  {}_{}_{}_{}:t{}:n -> {}_{}_{}_{}:b{}:s [color={}]",
786                    input_plan.name(),
787                    input_stage.num,
788                    input_task_i,
789                    1, // the repartition exec is always the first node in the plan
790                    p,
791                    plan.name(),
792                    stage.num,
793                    task_i,
794                    index,
795                    p + (input_task_i * input_partitions_per_task),
796                    p % NUM_COLORS + 1
797                )?;
798            }
799            continue;
800        }
801
802        for child in plan.children() {
803            queue.push_back(child);
804        }
805    }
806
807    Ok(f)
808}
809
/// Joins the partition numbers of a partition group with underscores, e.g. `[0, 1, 2]` becomes
/// `"0_1_2"`.
fn format_pg(partition_group: &[usize]) -> String {
    let mut out = String::new();
    for (pos, partition) in partition_group.iter().enumerate() {
        if pos > 0 {
            out.push('_');
        }
        out.push_str(&partition.to_string());
    }
    out
}
817
/// Returns the `partitions` consecutive partition numbers owned by task `task_i`, i.e.
/// `task_i * partitions .. (task_i + 1) * partitions`.
fn build_partition_group(task_i: usize, partitions: usize) -> Vec<usize> {
    let start = task_i * partitions;
    (start..start + partitions).collect()
}
821
822fn find_input_stages(plan: &dyn ExecutionPlan) -> Vec<&Stage> {
823    let mut result = vec![];
824    for child in plan.children() {
825        if let Some(plan) = child.as_network_boundary() {
826            result.push(plan.input_stage());
827        } else {
828            result.extend(find_input_stages(child.as_ref()));
829        }
830    }
831    result
832}
833
834pub(crate) fn find_all_stages(plan: &Arc<dyn ExecutionPlan>) -> Vec<&Stage> {
835    let mut result = vec![];
836    if let Some(plan) = plan.as_network_boundary() {
837        result.push(plan.input_stage());
838    }
839    for child in plan.children() {
840        result.extend(find_all_stages(child));
841    }
842    result
843}