use std::collections::HashSet;
use crate::graph::{Arrow, Direction, Edge, Graph, Node, NotePosition, Shape, Stroke, Subgraph};
use crate::mermaid::state::{
StateDecl, StateModel, StateStatement, StateStereotype, StateTransition,
};
/// Lowers a parsed state-diagram model into a layout [`Graph`].
///
/// The graph direction comes from `model.direction` (top-down when absent or unrecognised).
/// Statements are processed in order starting from a root scope named `"__root"`. Afterwards,
/// edges that target composite states are redirected onto concrete child nodes, and the
/// placeholder nodes for composites are dropped (see `resolve_subgraph_edges`).
pub fn compile(model: &StateModel) -> Graph {
    let direction = direction_from_str(model.direction.as_deref());
    let mut graph = Graph::new(direction);

    // Node ids created so far, shared across all scopes so ids are never duplicated.
    let mut known_ids = HashSet::new();
    // Running note index; makes every generated note node id unique diagram-wide.
    let mut notes_emitted: usize = 0;

    process_statements(
        &mut graph,
        &mut known_ids,
        &mut notes_emitted,
        &model.statements,
        None,
        "__root",
    );
    resolve_subgraph_edges(&mut graph);
    graph
}
/// Attaches a note to `state_id` as a standalone node joined by an undirected dashed edge.
///
/// The note becomes a [`Shape::NoteRect`] node with id `"{state_id}____note_{index}"` and
/// `text` as its label. It inherits the annotated state's `parent`. If that parent subgraph is
/// already present in `graph.subgraphs`, the note id is also appended to the subgraph's member
/// list. For [`NotePosition::Right`] the edge runs state → note; for [`NotePosition::Left`] it
/// runs note → state (presumably so the layout places the note on the requested side — confirm
/// against the layout engine).
///
/// Callers must pass an `index` that is unique within the graph; `process_statements` uses a
/// running counter for this.
fn add_note_node(
    graph: &mut Graph,
    state_id: &str,
    text: &str,
    position: NotePosition,
    index: usize,
) {
    let note_node_id = format!("{state_id}____note_{index}");
    // A note lives in the same cluster as the state it annotates.
    let state_parent = graph.nodes.get(state_id).and_then(|n| n.parent.clone());

    let mut note_node = Node::new(&note_node_id)
        .with_label(text)
        .with_shape(Shape::NoteRect);
    note_node.parent = state_parent.clone();
    graph.add_node(note_node);

    // Keep the subgraph's member list consistent with the node's `parent` field.
    if let Some(parent_sg) = state_parent
        .as_ref()
        .and_then(|parent_id| graph.subgraphs.get_mut(parent_id))
    {
        parent_sg.nodes.push(note_node_id.clone());
    }

    let (from, to) = match position {
        NotePosition::Right => (state_id.to_string(), note_node_id),
        NotePosition::Left => (note_node_id, state_id.to_string()),
    };
    graph.add_edge(
        Edge::new(&from, &to)
            .with_stroke(Stroke::Dashed)
            .with_arrows(Arrow::None, Arrow::None),
    );
}
/// Maps a Mermaid direction keyword to a [`Direction`].
///
/// Only the exact keywords `LR`, `RL` and `BT` select a non-default direction. `TB`, `TD`,
/// any other string, or `None` all yield [`Direction::TopDown`].
fn direction_from_str(dir: Option<&str>) -> Direction {
    // Treat a missing keyword exactly like an explicit top-to-bottom one.
    match dir.unwrap_or("TB") {
        "LR" => Direction::LeftRight,
        "RL" => Direction::RightLeft,
        "BT" => Direction::BottomTop,
        _ => Direction::TopDown,
    }
}
/// Lowers a sequence of state-diagram statements into `graph`, in source order.
///
/// * `seen_nodes` — ids of nodes already added; shared across all scopes.
/// * `note_counter` — running index shared across the whole diagram so every note node id is
///   unique; incremented once per note.
/// * `parent_subgraph` — id of the composite state these statements live in (`None` at the top
///   level). Newly created nodes get it as their `parent`.
/// * `scope` — namespace for this level's `[*]` start/end markers, so each composite gets its
///   own pair (see `star_node_id`).
///
/// `direction` statements are skipped here. The top-level direction is taken from
/// `model.direction` in `compile`, and a composite's own direction is read by
/// `process_state_decl`.
fn process_statements(
    graph: &mut Graph,
    seen_nodes: &mut HashSet<String>,
    note_counter: &mut usize,
    statements: &[StateStatement],
    parent_subgraph: Option<&str>,
    scope: &str,
) {
    for stmt in statements {
        match stmt {
            StateStatement::Transition(t) => {
                add_transition(graph, seen_nodes, t, parent_subgraph, scope);
            }
            StateStatement::State(decl) => {
                process_state_decl(
                    graph,
                    seen_nodes,
                    note_counter,
                    decl,
                    parent_subgraph,
                    scope,
                );
            }
            StateStatement::Direction(_) => {}
            StateStatement::Note(note) => {
                // A note may reference a state that has not been declared yet; create it in
                // the current scope so the note has something to attach to.
                ensure_implicit_node(graph, seen_nodes, &note.state_id, parent_subgraph);
                // Translate the parser's enum into the graph-layer enum.
                let position = match note.position {
                    crate::mermaid::state::NotePosition::Right => NotePosition::Right,
                    crate::mermaid::state::NotePosition::Left => NotePosition::Left,
                };
                add_note_node(graph, &note.state_id, &note.text, position, *note_counter);
                *note_counter += 1;
            }
        }
    }
}
/// Handles one `state` declaration.
///
/// A declaration without children becomes (or updates) a single node via
/// `ensure_state_node_with_decl`. A declaration with children becomes a subgraph with id
/// `decl.id`, titled with the alias (or id). Its direction comes from the first `direction`
/// statement among its children, if any. Its members are the child ids listed by
/// `collect_child_node_ids`, followed by any note nodes enrolled while its children were
/// processed. Those child nodes are re-parented to the composite, and the subgraph id is
/// appended to `graph.subgraph_order` after its children are processed, so nested composites
/// appear first.
fn process_state_decl(
    graph: &mut Graph,
    seen_nodes: &mut HashSet<String>,
    note_counter: &mut usize,
    decl: &StateDecl,
    parent_subgraph: Option<&str>,
    _scope: &str,
) {
    if decl.children.is_empty() {
        ensure_state_node_with_decl(graph, seen_nodes, decl, parent_subgraph);
        return;
    }

    let dir = decl.children.iter().find_map(|s| match s {
        StateStatement::Direction(d) => Some(direction_from_str(Some(d))),
        _ => None,
    });
    let child_ids = collect_child_node_ids(&decl.children, &decl.id);

    // Register the subgraph *before* processing the children. `add_note_node` enrols a note
    // into its state's parent subgraph only if that subgraph already exists; previously it was
    // inserted afterwards, so notes inside composites never appeared in `nodes`.
    graph.subgraphs.insert(
        decl.id.clone(),
        Subgraph {
            id: decl.id.clone(),
            title: decl.alias.as_deref().unwrap_or(&decl.id).to_string(),
            nodes: Vec::new(),
            parent: parent_subgraph.map(|s| s.to_string()),
            dir,
            invisible: false,
        },
    );

    // Children use this composite both as their parent and as their `[*]` scope.
    process_statements(
        graph,
        seen_nodes,
        note_counter,
        &decl.children,
        Some(&decl.id),
        &decl.id,
    );

    // Nodes referenced here may have been created earlier in an outer scope; move them in.
    for child_id in &child_ids {
        if let Some(node) = graph.nodes.get_mut(child_id) {
            node.parent = Some(decl.id.clone());
        }
    }

    // Final membership: statically collected children first, then anything (notes) that
    // enrolled itself during child processing, without duplicates.
    if let Some(sg) = graph.subgraphs.get_mut(&decl.id) {
        let enrolled = std::mem::replace(&mut sg.nodes, child_ids);
        for id in enrolled {
            if !sg.nodes.contains(&id) {
                sg.nodes.push(id);
            }
        }
    }
    graph.subgraph_order.push(decl.id.clone());
}
/// Lists the node ids that belong directly to a composite state's body, in first-appearance
/// order and without duplicates.
///
/// Transition endpoints count, with `[*]` mapped to this scope's start marker (as a source)
/// or end marker (as a target) via `star_node_id`. Leaf `state` declarations also count.
/// Nested composites, notes and `direction` statements contribute nothing.
fn collect_child_node_ids(statements: &[StateStatement], scope: &str) -> Vec<String> {
    let endpoint = |name: &str, is_source: bool| -> String {
        if name == "[*]" {
            star_node_id(scope, is_source)
        } else {
            name.to_string()
        }
    };

    let candidates = statements.iter().flat_map(|stmt| -> Vec<String> {
        match stmt {
            StateStatement::Transition(t) => vec![
                endpoint(t.from.as_str(), true),
                endpoint(t.to.as_str(), false),
            ],
            StateStatement::State(decl) if decl.children.is_empty() => vec![decl.id.clone()],
            _ => Vec::new(),
        }
    });

    let mut members: Vec<String> = Vec::new();
    let mut already: HashSet<String> = HashSet::new();
    for id in candidates {
        if already.insert(id.clone()) {
            members.push(id);
        }
    }
    members
}
/// Builds the id of a `[*]` marker node for `scope`.
///
/// Returns `"{scope}_start"` for a source marker and `"{scope}_end"` for a target marker, so
/// each scope has at most one start and one end node.
fn star_node_id(scope: &str, is_source: bool) -> String {
    let suffix = match is_source {
        true => "start",
        false => "end",
    };
    let mut id = String::with_capacity(scope.len() + 1 + suffix.len());
    id.push_str(scope);
    id.push('_');
    id.push_str(suffix);
    id
}
/// Creates or updates the node for a leaf `state` declaration.
///
/// The shape comes from the stereotype: `<<fork>>`/`<<join>>` → [`Shape::ForkJoin`],
/// `<<choice>>` → [`Shape::Diamond`], none → [`Shape::Round`]. Fork/join and choice nodes always
/// get an empty label. A round node is labelled with its alias (or id) when there are no
/// descriptions; otherwise the descriptions form the label.
///
/// If the node already exists (e.g. created implicitly by an earlier transition), it is updated
/// in place:
/// * a stereotype overrides the shape and clears the label;
/// * an alias replaces a label that is still the bare id, so late declarations behave like
///   first-time ones;
/// * descriptions are appended via `append_descriptions`;
/// * `parent` is filled in only if the node has none yet.
fn ensure_state_node_with_decl(
    graph: &mut Graph,
    seen_nodes: &mut HashSet<String>,
    decl: &StateDecl,
    parent: Option<&str>,
) {
    let shape = match &decl.stereotype {
        Some(StateStereotype::Fork | StateStereotype::Join) => Shape::ForkJoin,
        Some(StateStereotype::Choice) => Shape::Diamond,
        None => Shape::Round,
    };
    let display_name = decl.alias.as_deref().unwrap_or(&decl.id);
    let is_unlabeled_shape = shape == Shape::ForkJoin || shape == Shape::Diamond;

    if !seen_nodes.contains(&decl.id) {
        let label = if is_unlabeled_shape {
            String::new()
        } else {
            build_description_label(&decl.descriptions, display_name)
        };
        let mut node = Node::new(&decl.id).with_label(label).with_shape(shape);
        node.parent = parent.map(|s| s.to_string());
        graph.add_node(node);
        seen_nodes.insert(decl.id.clone());
        return;
    }

    let Some(node) = graph.nodes.get_mut(&decl.id) else {
        return;
    };
    if is_unlabeled_shape {
        node.shape = shape;
        node.label = String::new();
    } else if let Some(alias) = &decl.alias {
        // Previously an alias on an already-seen node was ignored, leaving the bare id as the
        // label. Only replace a label that has not been customised yet.
        if node.label == decl.id {
            node.label = alias.clone();
        }
    }
    if !decl.descriptions.is_empty() && !is_unlabeled_shape {
        append_descriptions(&mut node.label, &decl.descriptions, display_name);
    }
    if parent.is_some() && node.parent.is_none() {
        node.parent = parent.map(|s| s.to_string());
    }
}
/// Returns the id of the `[*]` marker for `scope`, creating the marker node on first use.
///
/// A source marker is drawn as [`Shape::SmallCircle`] and a target marker as
/// [`Shape::FramedCircle`]. Both have an empty label and are parented to `parent`. Repeated
/// calls for the same scope and role reuse the existing node.
fn resolve_star_node(
    graph: &mut Graph,
    seen_nodes: &mut HashSet<String>,
    is_source: bool,
    parent: Option<&str>,
    scope: &str,
) -> String {
    let marker_id = star_node_id(scope, is_source);
    // `insert` returns true only the first time this id is recorded.
    if seen_nodes.insert(marker_id.clone()) {
        let marker_shape = match is_source {
            true => Shape::SmallCircle,
            false => Shape::FramedCircle,
        };
        let mut marker = Node::new(&marker_id).with_label("").with_shape(marker_shape);
        marker.parent = parent.map(str::to_string);
        graph.add_node(marker);
    }
    marker_id
}
/// Builds a node label from a state's description lines.
///
/// * no descriptions → `display_name`;
/// * one description → that description;
/// * several → the first, then a [`Node::SEPARATOR`] line, then the rest, joined with `\n`.
fn build_description_label(descriptions: &[String], display_name: &str) -> String {
    match descriptions {
        [] => display_name.to_string(),
        [only] => only.clone(),
        [first, rest @ ..] => {
            let mut lines: Vec<&str> = Vec::with_capacity(descriptions.len() + 1);
            lines.push(first.as_str());
            lines.push(Node::SEPARATOR);
            lines.extend(rest.iter().map(String::as_str));
            lines.join("\n")
        }
    }
}
/// Merges additional description lines into an existing node label.
///
/// * empty `new_descs` → the label is left untouched;
/// * a label still equal to `display_name` is replaced by a label built from `new_descs` alone;
/// * a label that already contains [`Node::SEPARATOR`] gets each new line appended after `\n`;
/// * otherwise the current label becomes the first description and the label is rebuilt,
///   which inserts the separator.
fn append_descriptions(label: &mut String, new_descs: &[String], display_name: &str) {
    if new_descs.is_empty() {
        return;
    }

    if label.as_str() == display_name {
        *label = build_description_label(new_descs, display_name);
    } else if label.contains(Node::SEPARATOR) {
        new_descs.iter().for_each(|line| {
            label.push('\n');
            label.push_str(line);
        });
    } else {
        // Move the current label out instead of cloning it; it is overwritten right after.
        let mut combined: Vec<String> = Vec::with_capacity(new_descs.len() + 1);
        combined.push(std::mem::take(label));
        combined.extend_from_slice(new_descs);
        *label = build_description_label(&combined, display_name);
    }
}
/// Adds a plain [`Shape::Round`] node `id` under `parent`, unless `id` is already recorded in
/// `seen_nodes`.
///
/// The node keeps the default label assigned by `Node::new`. An existing node is never
/// modified (in particular, its `parent` is not changed).
fn ensure_implicit_node(
    graph: &mut Graph,
    seen_nodes: &mut HashSet<String>,
    id: &str,
    parent: Option<&str>,
) {
    if seen_nodes.contains(id) {
        return;
    }
    let mut fresh = Node::new(id).with_shape(Shape::Round);
    fresh.parent = parent.map(str::to_owned);
    graph.add_node(fresh);
    seen_nodes.insert(id.to_owned());
}
/// Adds the edge for a `from --> to [: label]` transition.
///
/// `[*]` endpoints resolve to this scope's start marker (as a source) or end marker (as a
/// target). Any other endpoint is created as an implicit round node if not seen before. The
/// edge has a normal arrowhead at the target only and carries the transition label, if any.
fn add_transition(
    graph: &mut Graph,
    seen_nodes: &mut HashSet<String>,
    t: &StateTransition,
    parent: Option<&str>,
    scope: &str,
) {
    let from_id = match t.from.as_str() {
        "[*]" => resolve_star_node(graph, seen_nodes, true, parent, scope),
        name => {
            ensure_implicit_node(graph, seen_nodes, name, parent);
            name.to_string()
        }
    };
    let to_id = match t.to.as_str() {
        "[*]" => resolve_star_node(graph, seen_nodes, false, parent, scope),
        name => {
            ensure_implicit_node(graph, seen_nodes, name, parent);
            name.to_string()
        }
    };

    let base = Edge::new(&from_id, &to_id).with_arrows(Arrow::None, Arrow::Normal);
    let edge = match &t.label {
        Some(label) => base.with_label(label),
        None => base,
    };
    graph.add_edge(edge);
}
fn resolve_subgraph_edges(graph: &mut Graph) {
let mut resolved_edges = Vec::new();
for edge in &graph.edges {
let (from, from_subgraph) = if graph.is_subgraph(&edge.from) {
match graph.find_subgraph_sink(&edge.from) {
Some(child) => (child, Some(edge.from.clone())),
None => continue,
}
} else {
(edge.from.clone(), None)
};
let (to, to_subgraph) = if graph.is_subgraph(&edge.to) {
match graph.find_non_cluster_child(&edge.to) {
Some(child) => (child, Some(edge.to.clone())),
None => continue,
}
} else {
(edge.to.clone(), None)
};
resolved_edges.push(Edge {
from,
to,
from_subgraph,
to_subgraph,
stroke: edge.stroke,
arrow_start: edge.arrow_start,
arrow_end: edge.arrow_end,
label: edge.label.clone(),
head_label: edge.head_label.clone(),
tail_label: edge.tail_label.clone(),
minlen: edge.minlen,
index: edge.index,
});
}
graph.edges = resolved_edges;
let subgraph_ids: Vec<String> = graph.subgraphs.keys().cloned().collect();
for sg_id in &subgraph_ids {
if graph.nodes.contains_key(sg_id) {
graph.nodes.remove(sg_id);
}
}
}
// Unit tests for the state-diagram compiler. Each test parses Mermaid text with
// `parse_state_diagram` and inspects the resulting `Graph`.
// NOTE: the multi-line fixtures rely on `"\` string continuation; every line after the first
// keeps its leading whitespace, so those lines are deliberately left unindented.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mermaid::state::parse_state_diagram;

    /// Parses `input` (panicking on a parse error) and compiles it to a `Graph`.
    fn compile_state(input: &str) -> Graph {
        let model = parse_state_diagram(input).unwrap();
        compile(&model)
    }

    /// A single transition yields both endpoint nodes and one edge between them.
    #[test]
    fn compiler_basic_transition_creates_nodes_and_edge() {
        let graph = compile_state("stateDiagram-v2\n A --> B");
        assert!(graph.nodes.contains_key("A"));
        assert!(graph.nodes.contains_key("B"));
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.edges[0].from, "A");
        assert_eq!(graph.edges[0].to, "B");
    }

    /// Implicitly created states use the round shape.
    #[test]
    fn compiler_state_nodes_are_round() {
        let graph = compile_state("stateDiagram-v2\n A --> B");
        assert_eq!(graph.nodes["A"].shape, Shape::Round);
        assert_eq!(graph.nodes["B"].shape, Shape::Round);
    }

    /// `[*]` as a source becomes a small-circle start marker.
    #[test]
    fn compiler_star_source_becomes_small_circle() {
        let graph = compile_state("stateDiagram-v2\n [*] --> Idle");
        let start_node = graph.nodes.values().find(|n| n.shape == Shape::SmallCircle);
        assert!(start_node.is_some());
        assert_eq!(graph.edges[0].to, "Idle");
    }

    /// `[*]` as a target becomes a framed-circle end marker.
    #[test]
    fn compiler_star_target_becomes_framed_circle() {
        let graph = compile_state("stateDiagram-v2\n Done --> [*]");
        let end_node = graph
            .nodes
            .values()
            .find(|n| n.shape == Shape::FramedCircle);
        assert!(end_node.is_some());
        assert_eq!(graph.edges[0].from, "Done");
    }

    /// The text after `:` on a transition becomes the edge label.
    #[test]
    fn compiler_transition_label_preserved() {
        let graph = compile_state("stateDiagram-v2\n A --> B : submit");
        assert_eq!(graph.edges[0].label, Some("submit".to_string()));
    }

    /// Without a `direction` statement the graph is top-down.
    #[test]
    fn compiler_default_direction_is_top_down() {
        let graph = compile_state("stateDiagram-v2\n A --> B");
        assert_eq!(graph.direction, Direction::TopDown);
    }

    /// A top-level `direction LR` sets the graph direction.
    #[test]
    fn compiler_direction_lr() {
        let graph = compile_state("stateDiagram-v2\n direction LR\n A --> B");
        assert_eq!(graph.direction, Direction::LeftRight);
    }

    /// `<<fork>>` maps to the fork/join bar shape.
    #[test]
    fn compiler_fork_stereotype_uses_fork_join_shape() {
        let graph =
            compile_state("stateDiagram-v2\n state forkNode <<fork>>\n A --> forkNode");
        assert_eq!(graph.nodes["forkNode"].shape, Shape::ForkJoin);
    }

    /// `<<choice>>` maps to the diamond shape.
    #[test]
    fn compiler_choice_stereotype_uses_diamond_shape() {
        let graph =
            compile_state("stateDiagram-v2\n state choiceNode <<choice>>\n A --> choiceNode");
        assert_eq!(graph.nodes["choiceNode"].shape, Shape::Diamond);
    }

    /// Within one scope all `[*]` sources share one start node and all `[*]` targets share
    /// one end node.
    #[test]
    fn compiler_star_markers_coalesce_per_scope() {
        let input = "\
stateDiagram-v2
[*] --> A
[*] --> B
A --> [*]
B --> [*]";
        let graph = compile_state(input);
        let start_nodes: Vec<_> = graph
            .nodes
            .values()
            .filter(|n| n.shape == Shape::SmallCircle)
            .collect();
        let end_nodes: Vec<_> = graph
            .nodes
            .values()
            .filter(|n| n.shape == Shape::FramedCircle)
            .collect();
        assert_eq!(start_nodes.len(), 1);
        assert_eq!(end_nodes.len(), 1);
        assert_eq!(
            graph
                .edges
                .iter()
                .filter(|e| e.from == start_nodes[0].id)
                .count(),
            2
        );
        assert_eq!(
            graph
                .edges
                .iter()
                .filter(|e| e.to == end_nodes[0].id)
                .count(),
            2
        );
    }

    /// A composite state gets its own start/end markers, separate from the root's.
    #[test]
    fn compiler_composite_gets_own_star_scope() {
        let input = "\
stateDiagram-v2
[*] --> Active
state Active {
[*] --> Running
Running --> [*]
}
Active --> [*]";
        let graph = compile_state(input);
        let start_nodes: Vec<_> = graph
            .nodes
            .values()
            .filter(|n| n.shape == Shape::SmallCircle)
            .collect();
        let end_nodes: Vec<_> = graph
            .nodes
            .values()
            .filter(|n| n.shape == Shape::FramedCircle)
            .collect();
        assert_eq!(start_nodes.len(), 2);
        assert_eq!(end_nodes.len(), 2);
    }

    /// A composite state becomes a top-level subgraph whose children point back to it.
    #[test]
    fn compiler_composite_state_creates_subgraph() {
        let input = "\
stateDiagram-v2
[*] --> Active
state Active {
[*] --> Running
Running --> [*]
}
Active --> [*]";
        let graph = compile_state(input);
        assert!(graph.subgraphs.contains_key("Active"));
        let sg = &graph.subgraphs["Active"];
        assert_eq!(sg.title, "Active");
        assert!(sg.parent.is_none());
        assert!(
            graph
                .nodes
                .values()
                .any(|n| n.parent.as_deref() == Some("Active"))
        );
    }

    /// A `direction` statement inside a composite sets that subgraph's direction.
    #[test]
    fn compiler_composite_direction_override() {
        let input = "\
stateDiagram-v2
state Processing {
direction LR
[*] --> Validating
Validating --> [*]
}";
        let graph = compile_state(input);
        let sg = &graph.subgraphs["Processing"];
        assert_eq!(sg.dir, Some(Direction::LeftRight));
    }

    /// A single description replaces the state's default label.
    #[test]
    fn compiler_state_description_replaces_label() {
        let input = "\
stateDiagram-v2
Active : The system is active
[*] --> Active";
        let graph = compile_state(input);
        assert_eq!(graph.nodes["Active"].label, "The system is active");
    }

    /// A stereotype declared after implicit creation still updates the shape.
    #[test]
    fn compiler_stereotype_overrides_implicit_shape() {
        let input = "\
stateDiagram-v2
A --> forkNode
state forkNode <<fork>>";
        let graph = compile_state(input);
        assert_eq!(graph.nodes["forkNode"].shape, Shape::ForkJoin);
    }

    /// Three named states plus start and end markers give five nodes and four edges.
    #[test]
    fn compiler_full_example() {
        let input = "\
stateDiagram-v2
[*] --> Idle
Idle --> Processing : submit
Processing --> Done : complete
Done --> [*]";
        let graph = compile_state(input);
        assert_eq!(graph.nodes.len(), 5);
        assert_eq!(graph.edges.len(), 4);
    }

    /// Two descriptions produce "first / SEPARATOR / second" label lines.
    #[test]
    fn compiler_multiline_descriptions_create_separator() {
        let input = "\
stateDiagram-v2
Server : Listening on port 8080
Server : Accepts TCP connections
[*] --> Server";
        let graph = compile_state(input);
        let label = &graph.nodes["Server"].label;
        let lines: Vec<&str> = label.lines().collect();
        assert_eq!(lines[0], "Listening on port 8080");
        assert_eq!(lines[1], Node::SEPARATOR);
        assert_eq!(lines[2], "Accepts TCP connections");
    }

    /// One description never introduces a separator line.
    #[test]
    fn compiler_single_description_no_separator() {
        let input = "\
stateDiagram-v2
Active : The system is active
[*] --> Active";
        let graph = compile_state(input);
        assert!(!graph.nodes["Active"].label.contains(Node::SEPARATOR));
        assert_eq!(graph.nodes["Active"].label, "The system is active");
    }

    /// Only one separator is used, after the first description; later ones follow directly.
    #[test]
    fn compiler_three_descriptions() {
        let input = "\
stateDiagram-v2
Server : Line 1
Server : Line 2
Server : Line 3";
        let graph = compile_state(input);
        let label = &graph.nodes["Server"].label;
        let lines: Vec<&str> = label.lines().collect();
        assert_eq!(lines.len(), 4); assert_eq!(lines[0], "Line 1"); // first line, separator, two more
        assert_eq!(lines[1], Node::SEPARATOR);
        assert_eq!(lines[2], "Line 2");
        assert_eq!(lines[3], "Line 3");
    }

    /// Descriptions added after a transition created the node replace the default label.
    #[test]
    fn compiler_description_after_implicit_creation() {
        let input = "\
stateDiagram-v2
[*] --> Server
Server : Listening
Server : Accepting";
        let graph = compile_state(input);
        let label = &graph.nodes["Server"].label;
        assert!(label.contains(Node::SEPARATOR));
        let lines: Vec<&str> = label.lines().collect();
        assert_eq!(lines[0], "Listening");
        assert_eq!(lines[1], Node::SEPARATOR);
        assert_eq!(lines[2], "Accepting");
    }

    /// A note becomes a NoteRect node linked to its state by an arrowless dashed edge,
    /// without using invisible subgraphs or the graph-level `notes` list.
    #[test]
    fn compiler_note_creates_standalone_node_with_constraint_edge() {
        let input = "\
stateDiagram-v2
[*] --> Active
note right of Active : This is a note";
        let graph = compile_state(input);
        let note_node = graph
            .nodes
            .values()
            .find(|n| n.shape == Shape::NoteRect)
            .expect("note node should exist");
        assert_eq!(note_node.label, "This is a note");
        assert!(
            !graph.subgraphs.values().any(|sg| sg.invisible),
            "should not create invisible subgraphs"
        );
        assert_eq!(note_node.parent, graph.nodes["Active"].parent);
        let dotted_edge = graph
            .edges
            .iter()
            .find(|e| e.stroke == Stroke::Dashed)
            .expect("dotted edge should exist");
        assert_eq!(dotted_edge.arrow_start, Arrow::None);
        assert_eq!(dotted_edge.arrow_end, Arrow::None);
        // `right of` runs the edge state -> note.
        assert_eq!(dotted_edge.from, "Active");
        assert_eq!(dotted_edge.to, note_node.id);
        assert!(graph.notes.is_empty());
    }

    /// A block note keeps its lines joined with `\n`.
    #[test]
    fn compiler_note_multiline() {
        let input = "\
stateDiagram-v2
Active --> [*]
note right of Active
Line one
Line two
end note";
        let graph = compile_state(input);
        let note_node = graph
            .nodes
            .values()
            .find(|n| n.shape == Shape::NoteRect)
            .expect("note node should exist");
        assert_eq!(note_node.label, "Line one\nLine two");
    }
}