kapot_scheduler/state/
execution_graph_dot.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utilities for producing GraphViz DOT diagrams from execution graphs
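//!
//! A minimal usage sketch (illustrative only; constructing an
//! [`ExecutionGraph`] requires a fully planned job, so this is not a
//! runnable doc test):
//!
//! ```ignore
//! let dot = ExecutionGraphDot::generate(&graph)?;
//! // the output can be rendered with the Graphviz CLI,
//! // e.g. `dot -Tsvg -o graph.svg graph.dot`
//! println!("{dot}");
//! ```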

use crate::state::execution_graph::ExecutionGraph;
use kapot_core::execution_plans::{
    ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec,
};
use datafusion::datasource::physical_plan::{
    AvroExec, CsvExec, FileScanConfig, NdJsonExec, ParquetExec,
};
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::joins::CrossJoinExec;
use datafusion::physical_plan::joins::HashJoinExec;
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr};
use log::debug;
use object_store::path::Path;
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::sync::Arc;

/// Utility for producing GraphViz DOT diagrams from execution graphs
pub struct ExecutionGraphDot<'a> {
    graph: &'a ExecutionGraph,
}

impl<'a> ExecutionGraphDot<'a> {
    /// Create a DOT graph from the provided ExecutionGraph
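    ///
    /// Stages are emitted in ascending stage id order, so the output is
    /// deterministic for a given graph.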
    pub fn generate(graph: &'a ExecutionGraph) -> Result<String, fmt::Error> {
        let mut dot = Self { graph };
        dot._generate()
    }

    /// Create a DOT graph for one query stage from the provided ExecutionGraph
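    ///
    /// Returns `fmt::Error` if `stage_id` is not present in the graph.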
    pub fn generate_for_query_stage(
        graph: &ExecutionGraph,
        stage_id: usize,
    ) -> Result<String, fmt::Error> {
        if let Some(stage) = graph.stages().get(&stage_id) {
            let mut dot = String::new();
            writeln!(&mut dot, "digraph G {{")?;
            let stage_name = format!("stage_{stage_id}");
            write_stage_plan(&mut dot, &stage_name, stage.plan(), 0)?;
            writeln!(&mut dot, "}}")?;
            Ok(dot)
        } else {
            Err(fmt::Error)
        }
    }

    fn _generate(&mut self) -> Result<String, fmt::Error> {
        // sort the stages by id so that the output is deterministic
        // (the tests rely on this ordering)
        let stages = self.graph.stages();
        let mut stage_ids: Vec<usize> = stages.keys().cloned().collect();
        stage_ids.sort();

        let mut dot = String::new();

        writeln!(&mut dot, "digraph G {{")?;

        let mut stage_meta = vec![];

        for (cluster, id) in stage_ids.iter().enumerate() {
            let stage = stages.get(id).unwrap(); // safe unwrap: id came from stages.keys()
            let stage_name = format!("stage_{id}");
            writeln!(&mut dot, "\tsubgraph cluster{cluster} {{")?;
            writeln!(
                &mut dot,
                "\t\tlabel = \"Stage {} [{}]\";",
                id,
                stage.variant_name()
            )?;
            stage_meta.push(write_stage_plan(&mut dot, &stage_name, stage.plan(), 0)?);
            writeln!(&mut dot, "\t}}")?; // end of subgraph
        }

        // write links between stages
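        // e.g. a shuffle reader in stage 3 that consumes the output of
        // stage 1 produces the edge `stage_1_0 -> stage_3_0_0_0_0_0`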
        for meta in &stage_meta {
            let mut links = vec![];
            for (reader_node, parent_stage_id) in &meta.readers {
                // shuffle write node is always node zero
                let parent_shuffle_write_node = format!("stage_{parent_stage_id}_0");
                links.push(format!("{parent_shuffle_write_node} -> {reader_node}"));
            }
            // keep the order deterministic
            links.sort();
            for link in links {
                writeln!(&mut dot, "\t{link}")?;
            }
        }

        writeln!(&mut dot, "}}")?; // end of digraph

        Ok(dot)
    }
}

/// Write the query tree for a single stage and build metadata needed to later draw
/// the links between the stages
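///
/// Node names encode the path from the stage root: `stage_3_0_0_1` is child 1
/// of child 0 of the root node (always node 0) of stage 3.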
fn write_stage_plan(
    f: &mut String,
    prefix: &str,
    plan: &dyn ExecutionPlan,
    i: usize,
) -> Result<StagePlanState, fmt::Error> {
    let mut state = StagePlanState {
        readers: HashMap::new(),
    };
    write_plan_recursive(f, prefix, plan, i, &mut state)?;
    Ok(state)
}

fn write_plan_recursive(
    f: &mut String,
    prefix: &str,
    plan: &dyn ExecutionPlan,
    i: usize,
    state: &mut StagePlanState,
) -> Result<(), fmt::Error> {
    let node_name = format!("{prefix}_{i}");
    let display_name = get_operator_name(plan);

    if let Some(reader) = plan.as_any().downcast_ref::<ShuffleReaderExec>() {
        for part in &reader.partition {
            for loc in part {
                state
                    .readers
                    .insert(node_name.clone(), loc.partition_id.stage_id);
            }
        }
    } else if let Some(reader) = plan.as_any().downcast_ref::<UnresolvedShuffleExec>() {
        state.readers.insert(node_name.clone(), reader.stage_id);
    }

    let mut metrics_str = vec![];
    if let Some(metrics) = plan.metrics() {
        if let Some(x) = metrics.output_rows() {
            metrics_str.push(format!("output_rows={x}"))
        }
        if let Some(x) = metrics.elapsed_compute() {
            metrics_str.push(format!("elapsed_compute={x}"))
        }
    }
    if metrics_str.is_empty() {
        writeln!(f, "\t\t{node_name} [shape=box, label=\"{display_name}\"]")?;
    } else {
        writeln!(
            f,
            "\t\t{} [shape=box, label=\"{}
{}\"]",
            node_name,
            display_name,
            metrics_str.join(", ")
        )?;
    }

    for (j, child) in plan.children().into_iter().enumerate() {
        write_plan_recursive(f, &node_name, child.as_ref(), j, state)?;
        // write link from child to parent
        writeln!(f, "\t\t{node_name}_{j} -> {node_name}")?;
    }

    Ok(())
}

#[derive(Debug)]
struct StagePlanState {
    /// map from reader node name to parent stage id
    readers: HashMap<String, usize>,
}

/// Make strings DOT-friendly, truncating long labels
fn sanitize_dot_label(str: &str) -> String {
    // TODO make max length configurable eventually
    sanitize(str, Some(100))
}

/// Make strings DOT-friendly
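///
/// For example (illustrative; `sanitize` is private, so this is not a
/// runnable doc test):
///
/// ```ignore
/// // a double quote becomes a backtick, non-ASCII becomes `?`
/// assert_eq!(sanitize("a\"b ≤ c", None), "a`b ? c");
/// // long labels are truncated with a trailing " ..."
/// assert_eq!(sanitize(&"x".repeat(120), Some(100)).len(), 104);
/// ```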
fn sanitize(str: &str, max_len: Option<usize>) -> String {
    let mut sanitized = String::new();
    for ch in str.chars() {
        match ch {
            '"' => sanitized.push('`'),
            ' ' | '_' | '+' | '-' | '*' | '/' | '(' | ')' | '[' | ']' | '{' | '}'
            | '!' | '@' | '#' | '$' | '%' | '&' | '=' | ':' | ';' | '\\' | '\'' | '.'
            | ',' | '<' | '>' | '`' => sanitized.push(ch),
            _ if ch.is_ascii_alphanumeric() || ch.is_ascii_whitespace() => {
                sanitized.push(ch)
            }
            _ => sanitized.push('?'),
        }
    }
    // truncate after translation because we know we only have ASCII chars at this point
    // so the slice is safe (not splitting unicode character bytes)
    if let Some(limit) = max_len {
        if sanitized.len() > limit {
            sanitized.truncate(limit);
            return sanitized + " ...";
        }
    }
    sanitized
}

fn get_operator_name(plan: &dyn ExecutionPlan) -> String {
    if let Some(exec) = plan.as_any().downcast_ref::<FilterExec>() {
        format!("Filter: {}", exec.predicate())
    } else if let Some(exec) = plan.as_any().downcast_ref::<ProjectionExec>() {
        let expr = exec
            .expr()
            .iter()
            .map(|(e, _)| format!("{e}"))
            .collect::<Vec<String>>()
            .join(", ");
        format!("Projection: {}", sanitize_dot_label(&expr))
    } else if let Some(exec) = plan.as_any().downcast_ref::<SortExec>() {
        let sort_expr = exec
            .expr()
            .iter()
            .map(|e| {
                let asc = if e.options.descending { " DESC" } else { "" };
                let nulls = if e.options.nulls_first {
                    " NULLS FIRST"
                } else {
                    ""
                };
                format!("{}{}{}", e.expr, asc, nulls)
            })
            .collect::<Vec<String>>()
            .join(", ");
        format!("Sort: {}", sanitize_dot_label(&sort_expr))
    } else if let Some(exec) = plan.as_any().downcast_ref::<AggregateExec>() {
        let group_exprs_with_alias = exec.group_expr().expr();
        let group_expr = group_exprs_with_alias
            .iter()
            .map(|(e, _)| format!("{e}"))
            .collect::<Vec<String>>()
            .join(", ");
        let aggr_expr = exec
            .aggr_expr()
            .iter()
            .map(|e| e.name().to_owned())
            .collect::<Vec<String>>()
            .join(", ");
        format!(
            "Aggregate
groupBy=[{}]
aggr=[{}]",
            sanitize_dot_label(&group_expr),
            sanitize_dot_label(&aggr_expr)
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<CoalesceBatchesExec>() {
        format!("CoalesceBatches [batchSize={}]", exec.target_batch_size())
    } else if let Some(exec) = plan.as_any().downcast_ref::<CoalescePartitionsExec>() {
        format!(
            "CoalescePartitions [{}]",
            format_partitioning(exec.properties().output_partitioning().clone())
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<RepartitionExec>() {
        format!(
            "RepartitionExec [{}]",
            format_partitioning(exec.properties().output_partitioning().clone())
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<HashJoinExec>() {
        let join_expr = exec
            .on()
            .iter()
            .map(|(l, r)| format!("{l} = {r}"))
            .collect::<Vec<String>>()
            .join(" AND ");
        let filter_expr = if let Some(f) = exec.filter() {
            format!("{}", f.expression())
        } else {
            "".to_string()
        };
        format!(
            "HashJoin
join_expr={}
filter_expr={}",
            sanitize_dot_label(&join_expr),
            sanitize_dot_label(&filter_expr)
        )
    } else if plan.as_any().downcast_ref::<CrossJoinExec>().is_some() {
        "CrossJoin".to_string()
    } else if plan.as_any().downcast_ref::<UnionExec>().is_some() {
        "Union".to_string()
    } else if let Some(exec) = plan.as_any().downcast_ref::<UnresolvedShuffleExec>() {
        format!("UnresolvedShuffleExec [stage_id={}]", exec.stage_id)
    } else if let Some(exec) = plan.as_any().downcast_ref::<ShuffleReaderExec>() {
        format!("ShuffleReader [{} partitions]", exec.partition.len())
    } else if let Some(exec) = plan.as_any().downcast_ref::<ShuffleWriterExec>() {
        format!(
            "ShuffleWriter [{} partitions]",
            exec.input_partition_count()
        )
    } else if plan.as_any().downcast_ref::<MemoryExec>().is_some() {
        "MemoryExec".to_string()
    } else if let Some(exec) = plan.as_any().downcast_ref::<CsvExec>() {
        let parts = exec.properties().output_partitioning().partition_count();
        format!(
            "CSV: {} [{} partitions]",
            get_file_scan(exec.base_config()),
            parts
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<NdJsonExec>() {
        let parts = exec.properties().output_partitioning().partition_count();
        format!("JSON [{parts} partitions]")
    } else if let Some(exec) = plan.as_any().downcast_ref::<AvroExec>() {
        let parts = exec.properties().output_partitioning().partition_count();
        format!(
            "Avro: {} [{} partitions]",
            get_file_scan(exec.base_config()),
            parts
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<ParquetExec>() {
        let parts = exec.properties().output_partitioning().partition_count();
        format!(
            "Parquet: {} [{} partitions]",
            get_file_scan(exec.base_config()),
            parts
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<GlobalLimitExec>() {
        format!(
            "GlobalLimit(skip={}, fetch={:?})",
            exec.skip(),
            exec.fetch()
        )
    } else if let Some(exec) = plan.as_any().downcast_ref::<LocalLimitExec>() {
        format!("LocalLimit({})", exec.fetch())
    } else {
        debug!(
            "Unknown physical operator when producing DOT graph: {:?}",
            plan
        );
        "Unknown Operator".to_string()
    }
}

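/// Render a `Partitioning` as a short human-readable label, e.g.
/// `RoundRobinBatch(48)` becomes `48 partitions`, while hash partitioning
/// additionally lists its expressions.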
fn format_partitioning(x: Partitioning) -> String {
    match x {
        Partitioning::UnknownPartitioning(n) | Partitioning::RoundRobinBatch(n) => {
            format!("{n} partitions")
        }
        Partitioning::Hash(expr, n) => {
            format!("{} partitions, expr={}", n, format_expr_list(&expr))
        }
    }
}

fn format_expr_list(exprs: &[Arc<dyn PhysicalExpr>]) -> String {
    let expr_strings: Vec<String> = exprs.iter().map(|e| format!("{e}")).collect();
    expr_strings.join(", ")
}

/// Get summary of file scan locations
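///
/// Returns the single file path when there is one file, or the parent
/// directory of the first file plus a file count (e.g. `/tmp/data [10 files]`)
/// when there are several. Returns an empty string when there are no file groups.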
fn get_file_scan(scan: &FileScanConfig) -> String {
    if !scan.file_groups.is_empty() {
        // collect the location of every file across all file groups,
        // borrowing rather than cloning each PartitionedFile
        let paths: Vec<&Path> = scan
            .file_groups
            .iter()
            .flatten()
            .map(|part_file| &part_file.object_meta.location)
            .collect();
        match paths.len() {
            0 => "No files found".to_owned(),
            1 => {
                // single file
                format!("{}", paths[0])
            }
            _ => {
                // multiple files so show the parent directory
                let path = format!("{}", paths[0]);
                let path = if let Some(i) = path.rfind('/') {
                    &path[0..i]
                } else {
                    &path
                };
                format!("{} [{} files]", path, paths.len())
            }
        }
    } else {
        "".to_string()
    }
}

#[cfg(test)]
mod tests {
    use crate::state::execution_graph::ExecutionGraph;
    use crate::state::execution_graph_dot::ExecutionGraphDot;
    use kapot_core::error::{KapotError, Result};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use datafusion::prelude::{SessionConfig, SessionContext};
    use std::sync::Arc;

    #[tokio::test]
    async fn dot() -> Result<()> {
        let graph = test_graph().await?;
        let dot = ExecutionGraphDot::generate(&graph)
            .map_err(|e| KapotError::Internal(format!("{e:?}")))?;

        let expected = r#"digraph G {
	subgraph cluster0 {
		label = "Stage 1 [Resolved]";
		stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"]
		stage_1_0_0 [shape=box, label="MemoryExec"]
		stage_1_0_0 -> stage_1_0
	}
	subgraph cluster1 {
		label = "Stage 2 [Resolved]";
		stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"]
		stage_2_0_0 [shape=box, label="MemoryExec"]
		stage_2_0_0 -> stage_2_0
	}
	subgraph cluster2 {
		label = "Stage 3 [Unresolved]";
		stage_3_0 [shape=box, label="ShuffleWriter [48 partitions]"]
		stage_3_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_3_0_0_0 [shape=box, label="HashJoin
join_expr=a@0 = a@0
filter_expr="]
		stage_3_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_3_0_0_0_0_0 [shape=box, label="UnresolvedShuffleExec [stage_id=1]"]
		stage_3_0_0_0_0_0 -> stage_3_0_0_0_0
		stage_3_0_0_0_0 -> stage_3_0_0_0
		stage_3_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_3_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=2]"]
		stage_3_0_0_0_1_0 -> stage_3_0_0_0_1
		stage_3_0_0_0_1 -> stage_3_0_0_0
		stage_3_0_0_0 -> stage_3_0_0
		stage_3_0_0 -> stage_3_0
	}
	subgraph cluster3 {
		label = "Stage 4 [Resolved]";
		stage_4_0 [shape=box, label="ShuffleWriter [2 partitions]"]
		stage_4_0_0 [shape=box, label="MemoryExec"]
		stage_4_0_0 -> stage_4_0
	}
	subgraph cluster4 {
		label = "Stage 5 [Unresolved]";
		stage_5_0 [shape=box, label="ShuffleWriter [48 partitions]"]
		stage_5_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_5_0_0_0 [shape=box, label="HashJoin
join_expr=b@3 = b@1
filter_expr="]
		stage_5_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_5_0_0_0_0_0 [shape=box, label="UnresolvedShuffleExec [stage_id=3]"]
		stage_5_0_0_0_0_0 -> stage_5_0_0_0_0
		stage_5_0_0_0_0 -> stage_5_0_0_0
		stage_5_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_5_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=4]"]
		stage_5_0_0_0_1_0 -> stage_5_0_0_0_1
		stage_5_0_0_0_1 -> stage_5_0_0_0
		stage_5_0_0_0 -> stage_5_0_0
		stage_5_0_0 -> stage_5_0
	}
	stage_1_0 -> stage_3_0_0_0_0_0
	stage_2_0 -> stage_3_0_0_0_1_0
	stage_3_0 -> stage_5_0_0_0_0_0
	stage_4_0 -> stage_5_0_0_0_1_0
}
"#;
        assert_eq!(expected, &dot);
        Ok(())
    }

    #[tokio::test]
    async fn query_stage() -> Result<()> {
        let graph = test_graph().await?;
        let dot = ExecutionGraphDot::generate_for_query_stage(&graph, 3)
            .map_err(|e| KapotError::Internal(format!("{e:?}")))?;

        let expected = r#"digraph G {
		stage_3_0 [shape=box, label="ShuffleWriter [48 partitions]"]
		stage_3_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_3_0_0_0 [shape=box, label="HashJoin
join_expr=a@0 = a@0
filter_expr="]
		stage_3_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_3_0_0_0_0_0 [shape=box, label="UnresolvedShuffleExec [stage_id=1]"]
		stage_3_0_0_0_0_0 -> stage_3_0_0_0_0
		stage_3_0_0_0_0 -> stage_3_0_0_0
		stage_3_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_3_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=2]"]
		stage_3_0_0_0_1_0 -> stage_3_0_0_0_1
		stage_3_0_0_0_1 -> stage_3_0_0_0
		stage_3_0_0_0 -> stage_3_0_0
		stage_3_0_0 -> stage_3_0
}
"#;
        assert_eq!(expected, &dot);
        Ok(())
    }

    #[tokio::test]
    async fn dot_optimized() -> Result<()> {
        let graph = test_graph_optimized().await?;
        let dot = ExecutionGraphDot::generate(&graph)
            .map_err(|e| KapotError::Internal(format!("{e:?}")))?;

        let expected = r#"digraph G {
	subgraph cluster0 {
		label = "Stage 1 [Resolved]";
		stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"]
		stage_1_0_0 [shape=box, label="MemoryExec"]
		stage_1_0_0 -> stage_1_0
	}
	subgraph cluster1 {
		label = "Stage 2 [Resolved]";
		stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"]
		stage_2_0_0 [shape=box, label="MemoryExec"]
		stage_2_0_0 -> stage_2_0
	}
	subgraph cluster2 {
		label = "Stage 3 [Resolved]";
		stage_3_0 [shape=box, label="ShuffleWriter [2 partitions]"]
		stage_3_0_0 [shape=box, label="MemoryExec"]
		stage_3_0_0 -> stage_3_0
	}
	subgraph cluster3 {
		label = "Stage 4 [Unresolved]";
		stage_4_0 [shape=box, label="ShuffleWriter [48 partitions]"]
		stage_4_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0 [shape=box, label="HashJoin
join_expr=a@1 = a@0
filter_expr="]
		stage_4_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_0_0 [shape=box, label="HashJoin
join_expr=a@0 = a@0
filter_expr="]
		stage_4_0_0_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_0_0_0_0 [shape=box, label="UnresolvedShuffleExec [stage_id=1]"]
		stage_4_0_0_0_0_0_0_0 -> stage_4_0_0_0_0_0_0
		stage_4_0_0_0_0_0_0 -> stage_4_0_0_0_0_0
		stage_4_0_0_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=2]"]
		stage_4_0_0_0_0_0_1_0 -> stage_4_0_0_0_0_0_1
		stage_4_0_0_0_0_0_1 -> stage_4_0_0_0_0_0
		stage_4_0_0_0_0_0 -> stage_4_0_0_0_0
		stage_4_0_0_0_0 -> stage_4_0_0_0
		stage_4_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=3]"]
		stage_4_0_0_0_1_0 -> stage_4_0_0_0_1
		stage_4_0_0_0_1 -> stage_4_0_0_0
		stage_4_0_0_0 -> stage_4_0_0
		stage_4_0_0 -> stage_4_0
	}
	stage_1_0 -> stage_4_0_0_0_0_0_0_0
	stage_2_0 -> stage_4_0_0_0_0_0_1_0
	stage_3_0 -> stage_4_0_0_0_1_0
}
"#;
        assert_eq!(expected, &dot);
        Ok(())
    }

    #[tokio::test]
    async fn query_stage_optimized() -> Result<()> {
        let graph = test_graph_optimized().await?;
        let dot = ExecutionGraphDot::generate_for_query_stage(&graph, 4)
            .map_err(|e| KapotError::Internal(format!("{e:?}")))?;

        let expected = r#"digraph G {
		stage_4_0 [shape=box, label="ShuffleWriter [48 partitions]"]
		stage_4_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0 [shape=box, label="HashJoin
join_expr=a@1 = a@0
filter_expr="]
		stage_4_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_0_0 [shape=box, label="HashJoin
join_expr=a@0 = a@0
filter_expr="]
		stage_4_0_0_0_0_0_0 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_0_0_0_0 [shape=box, label="UnresolvedShuffleExec [stage_id=1]"]
		stage_4_0_0_0_0_0_0_0 -> stage_4_0_0_0_0_0_0
		stage_4_0_0_0_0_0_0 -> stage_4_0_0_0_0_0
		stage_4_0_0_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=2]"]
		stage_4_0_0_0_0_0_1_0 -> stage_4_0_0_0_0_0_1
		stage_4_0_0_0_0_0_1 -> stage_4_0_0_0_0_0
		stage_4_0_0_0_0_0 -> stage_4_0_0_0_0
		stage_4_0_0_0_0 -> stage_4_0_0_0
		stage_4_0_0_0_1 [shape=box, label="CoalesceBatches [batchSize=4096]"]
		stage_4_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec [stage_id=3]"]
		stage_4_0_0_0_1_0 -> stage_4_0_0_0_1
		stage_4_0_0_0_1 -> stage_4_0_0_0
		stage_4_0_0_0 -> stage_4_0_0
		stage_4_0_0 -> stage_4_0
}
"#;
        assert_eq!(expected, &dot);
        Ok(())
    }

    async fn test_graph() -> Result<ExecutionGraph> {
        let mut config = SessionConfig::new()
            .with_target_partitions(48)
            .with_batch_size(4096);
        config
            .options_mut()
            .optimizer
            .enable_round_robin_repartition = false;
        let ctx = SessionContext::new_with_config(config);
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::UInt32, false),
        ]));
        let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![], vec![]])?);
        ctx.register_table("foo", table.clone())?;
        ctx.register_table("bar", table.clone())?;
        ctx.register_table("baz", table)?;
        let df = ctx
            .sql("SELECT * FROM foo JOIN bar ON foo.a = bar.a JOIN baz on bar.b = baz.b")
            .await?;
        let plan = df.into_optimized_plan()?;
        let plan = ctx.state().create_physical_plan(&plan).await?;
        ExecutionGraph::new("scheduler_id", "job_id", "job_name", "session_id", plan, 0)
    }

    // With the improvement from https://github.com/apache/arrow-datafusion/pull/4122,
    // redundant RepartitionExec nodes can be removed, reducing the number of stages
    async fn test_graph_optimized() -> Result<ExecutionGraph> {
        let mut config = SessionConfig::new()
            .with_target_partitions(48)
            .with_batch_size(4096);
        config
            .options_mut()
            .optimizer
            .enable_round_robin_repartition = false;
        let ctx = SessionContext::new_with_config(config);
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, false)]));
        // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611
        let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![], vec![]])?);
        ctx.register_table("foo", table.clone())?;
        ctx.register_table("bar", table.clone())?;
        ctx.register_table("baz", table)?;
        let df = ctx
            .sql("SELECT * FROM foo JOIN bar ON foo.a = bar.a JOIN baz on bar.a = baz.a")
            .await?;
        let plan = df.into_optimized_plan()?;
        let plan = ctx.state().create_physical_plan(&plan).await?;
        ExecutionGraph::new("scheduler_id", "job_id", "job_name", "session_id", plan, 0)
    }
}
675}