python_to_mermaid/
mermaid.rs

1use std::{cell::RefCell, fmt, rc::Rc};
2
3use itertools::Itertools as _;
4
5#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
6pub struct NodeId(String);
7
8impl fmt::Display for NodeId {
9    fn fmt(&self, b: &mut fmt::Formatter) -> fmt::Result {
10        write!(b, "{}", self.0)
11    }
12}
13
14#[derive(Debug, Clone)]
15struct NodeIdGenerator {
16    digits: Vec<char>,
17}
18
19impl NodeIdGenerator {
20    pub fn new() -> Self {
21        Self { digits: vec![] }
22    }
23
24    pub fn step(&mut self) {
25        for i in (0..self.digits.len()).rev() {
26            if self.digits[i] != 'Z' {
27                self.digits[i] = (self.digits[i] as u8 + 1) as char;
28                return;
29            }
30
31            self.digits[i] = 'A';
32        }
33
34        // All characters are 'Z' if we reach this point
35        self.digits.push('A');
36    }
37
38    pub fn generate(&mut self) -> NodeId {
39        self.step();
40        NodeId(self.to_string())
41    }
42}
43
44impl fmt::Display for NodeIdGenerator {
45    fn fmt(&self, b: &mut fmt::Formatter) -> fmt::Result {
46        write!(b, "{}", self.digits.iter().format(""))
47    }
48}
49
50impl Iterator for NodeIdGenerator {
51    type Item = NodeId;
52
53    fn next(&mut self) -> Option<Self::Item> {
54        self.step();
55        Some(self.generate())
56    }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
60pub struct Node {
61    id: NodeId,
62    label: String,
63    shape: NodeShape,
64}
65
66impl fmt::Display for Node {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match self.shape {
69            NodeShape::Rounded => write!(f, r#"{}("{}")"#, self.id, self.label),
70            NodeShape::Rectangle => write!(f, r#"{}["{}"]"#, self.id, self.label),
71            NodeShape::Diamond => write!(f, r#"{}{{"{}"}}"#, self.id, self.label),
72            NodeShape::Trapezoid => write!(f, r#"{}[/"{}"\]"#, self.id, self.label),
73            NodeShape::InvertedTrapezoid => write!(f, r#"{}[\"{}"/]"#, self.id, self.label),
74        }
75    }
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
79pub enum NodeShape {
80    Rounded,
81    Rectangle,
82    Diamond,
83    Trapezoid,
84    InvertedTrapezoid,
85}
86
87#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
88pub struct Edge {
89    id0: NodeId,
90    id1: NodeId,
91    label: Option<String>,
92}
93
94#[derive(Debug, Clone)]
95pub struct MermaidFlowchart {
96    graph: MermaidGraph,
97}
98
99#[derive(Debug, Clone)]
100pub struct MermaidGraph {
101    id_gen: Rc<RefCell<NodeIdGenerator>>,
102    label: Option<String>,
103    nodes: Vec<Node>,
104    edges: Vec<Edge>,
105    subgraphs: Vec<MermaidGraph>,
106}
107
108impl MermaidFlowchart {
109    pub fn new() -> Self {
110        let id_gen = Rc::new(RefCell::new(NodeIdGenerator::new()));
111        let graph = MermaidGraph::new(Rc::clone(&id_gen));
112        Self { graph }
113    }
114
115    pub fn graph_mut(&mut self) -> &mut MermaidGraph {
116        &mut self.graph
117    }
118
119    pub fn render<W: fmt::Write>(self, writer: &mut W) {
120        writeln!(writer, "flowchart TD;").unwrap();
121        self.graph.render(writer);
122    }
123}
124
125impl Default for MermaidFlowchart {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl MermaidGraph {
132    fn new(id_gen: Rc<RefCell<NodeIdGenerator>>) -> Self {
133        Self {
134            id_gen,
135            label: None,
136            nodes: vec![],
137            edges: vec![],
138            subgraphs: vec![],
139        }
140    }
141
142    fn with_label(id_gen: Rc<RefCell<NodeIdGenerator>>, label: String) -> Self {
143        Self {
144            id_gen,
145            label: Some(label),
146            nodes: vec![],
147            edges: vec![],
148            subgraphs: vec![],
149        }
150    }
151
152    pub fn add_node(&mut self, label: impl Into<String>, shape: NodeShape) -> NodeId {
153        let label = label.into();
154        let id = self.id_gen.borrow_mut().generate();
155        self.nodes.push(Node {
156            id: id.clone(),
157            label,
158            shape,
159        });
160
161        id
162    }
163
164    pub fn add_edge(&mut self, id0: &NodeId, id1: &NodeId, label: Option<&str>) {
165        self.edges.push(Edge {
166            id0: id0.clone(),
167            id1: id1.clone(),
168            label: label.map(|s| s.to_string()),
169        });
170    }
171
172    pub fn add_subgraph(&mut self, label: String) -> &mut MermaidGraph {
173        let subgraph = MermaidGraph::with_label(Rc::clone(&self.id_gen), label);
174        self.subgraphs.push(subgraph);
175        self.subgraphs.last_mut().unwrap()
176    }
177
178    pub fn render<W: fmt::Write>(self, writer: &mut W) {
179        self.render_nodes(writer);
180        self.render_edges(writer);
181    }
182
183    fn render_nodes<W: fmt::Write>(&self, writer: &mut W) {
184        for node in &self.nodes {
185            writeln!(writer, "{};", node).unwrap();
186        }
187
188        for subgraph in &self.subgraphs {
189            writeln!(
190                writer,
191                "subgraph \"{}\"",
192                subgraph.label.as_deref().unwrap_or("")
193            )
194            .unwrap();
195            subgraph.render_nodes(writer);
196            writeln!(writer, "end").unwrap();
197        }
198    }
199
200    fn render_edges<W: fmt::Write>(&self, writer: &mut W) {
201        for edge in &self.edges {
202            if let Some(label) = &edge.label {
203                writeln!(writer, r#"{} -->|"{}"| {};"#, edge.id0, label, edge.id1).unwrap();
204            } else {
205                writeln!(writer, "{} --> {};", edge.id0, edge.id1).unwrap();
206            }
207        }
208
209        for subgraph in &self.subgraphs {
210            // Do not include edges within `subgraph ... end` block
211            subgraph.render_edges(writer);
212        }
213    }
214}