use crate::{
Error,
types::{Direction, Edge, EdgeEndpoint, EdgeStyle, Graph, Node, NodeShape, Subgraph},
};
/// Parses Mermaid flowchart source (`graph …` / `flowchart …`) into a [`Graph`].
///
/// Statements are separated by newlines or `;`; carriage returns are discarded and blank
/// statements ignored. A `%%` comment runs to the end of its line, so a `;` inside a comment
/// does not start a new statement. The first remaining statement must be the header.
///
/// # Errors
///
/// Returns [`Error::ParseError`] if the header is missing, uses an unknown keyword, or names
/// an unknown direction.
pub fn parse(input: &str) -> Result<Graph, Error> {
    let normalised = input.replace('\r', "");
    let statements: Vec<&str> = normalised
        .split('\n')
        .flat_map(|line| {
            line.split(';')
                .map(str::trim)
                // Everything from a `%%` statement to the end of the line is comment text;
                // filtering after a global `;` split would leak its tail into the diagram.
                .take_while(|stmt| !stmt.starts_with("%%"))
        })
        .filter(|stmt| !stmt.is_empty())
        .collect();
    let mut iter = statements.iter().copied();
    let direction = parse_header_stmt(&mut iter)?;
    let mut graph = Graph::new(direction);
    let remaining: Vec<&str> = iter.collect();
    parse_statements(&remaining, &mut graph, &mut None);
    Ok(graph)
}
/// Consumes the first statement from `stmts` and parses it as the diagram header.
///
/// The header is `graph` or `flowchart` (compared case-insensitively), optionally followed
/// by a direction that defaults to `TD`. Words may be separated by any run of whitespace,
/// so `graph   LR` and `graph\tLR` are both accepted. Anything after the direction is ignored.
///
/// # Errors
///
/// Returns [`Error::ParseError`] when there is no statement, when the keyword is neither
/// `graph` nor `flowchart`, or when `Direction::parse` rejects the direction word.
fn parse_header_stmt<'a>(stmts: &mut impl Iterator<Item = &'a str>) -> Result<Direction, Error> {
    let stmt = stmts
        .next()
        .ok_or_else(|| Error::ParseError("no 'graph'/'flowchart' header found".to_string()))?;
    // `split_whitespace` collapses runs of whitespace; a fixed `splitn` would yield an empty
    // field for `graph  LR` and silently fall back to the default direction.
    let mut words = stmt.split_whitespace();
    let keyword = words.next().unwrap_or("").to_lowercase();
    if keyword != "graph" && keyword != "flowchart" {
        return Err(Error::ParseError(format!(
            "expected 'graph' or 'flowchart', got '{keyword}'"
        )));
    }
    let dir_str = words.next().unwrap_or("TD");
    Direction::parse(dir_str)
        .ok_or_else(|| Error::ParseError(format!("unknown direction '{dir_str}'")))
}
/// Parses a flat list of statements into `graph`, recursing into `subgraph … end` blocks.
///
/// `current_subgraph_id` is the id of the innermost enclosing subgraph, or `None` at the top
/// level. Returns the number of statements consumed from `stmts`. For a subgraph body this
/// count includes the closing `end`, so the caller can skip the whole block. If the block is
/// never closed, every remaining statement is consumed.
fn parse_statements(
    stmts: &[&str],
    graph: &mut Graph,
    current_subgraph_id: &mut Option<String>,
) -> usize {
    let mut i = 0;
    while i < stmts.len() {
        let stmt = stmts[i];
        let first_word = stmt.split_whitespace().next().unwrap_or("");
        match first_word {
            "subgraph" => {
                let (sg_id, sg_label) = parse_subgraph_header(stmt);
                // Record the nesting on the parent before the child subgraph is created.
                if let Some(parent_id) = current_subgraph_id.as_deref()
                    && let Some(parent) = graph.subgraphs.iter_mut().find(|s| s.id == parent_id)
                {
                    parent.subgraph_ids.push(sg_id.clone());
                }
                graph.subgraphs.push(Subgraph::new(sg_id.clone(), sg_label));
                let mut inner_sg = Some(sg_id);
                i += 1;
                // The recursive call reports how many statements (including `end`) it used.
                i += parse_statements(&stmts[i..], graph, &mut inner_sg);
            }
            "end" if current_subgraph_id.is_some() => {
                return i + 1;
            }
            "direction" => {
                // Only applied inside a subgraph; ignored at the top level.
                if let Some(sg_id) = current_subgraph_id.as_deref()
                    && let Some(dir) =
                        Direction::parse(stmt.split_whitespace().nth(1).unwrap_or(""))
                    && let Some(sg) = graph.subgraphs.iter_mut().find(|s| s.id == sg_id)
                {
                    sg.direction = Some(dir);
                }
                i += 1;
            }
            // A stray top-level `end` has nothing to close. Skip it like the other unsupported
            // statements rather than returning early and discarding the rest of the diagram.
            "end" | "style" | "classDef" | "class" | "click" | "linkStyle" | "accTitle"
            | "accDescr" => {
                i += 1;
            }
            _ => {
                parse_statement(stmt, graph, current_subgraph_id);
                i += 1;
            }
        }
    }
    stmts.len()
}
/// Parses one ordinary (non-keyword) statement into `graph`.
///
/// A statement containing an arrow is handled as an edge chain. Anything else is treated as
/// a standalone node declaration, which is upserted and registered with the enclosing
/// subgraph (if any). Statements that do not yield a node are ignored.
fn parse_statement(stmt: &str, graph: &mut Graph, current_subgraph_id: &mut Option<String>) {
    if looks_like_edge_chain(stmt) {
        parse_edge_chain(stmt, graph, current_subgraph_id);
        return;
    }
    let Some(node) = parse_node_definition(stmt) else {
        return;
    };
    let node_id = node.id.clone();
    graph.upsert_node(node);
    register_node_in_subgraph(graph, &node_id, current_subgraph_id);
}
fn register_node_in_subgraph(
graph: &mut Graph,
node_id: &str,
current_subgraph_id: &Option<String>,
) {
if let Some(sg_id) = current_subgraph_id
&& let Some(sg) = graph.subgraphs.iter_mut().find(|s| s.id == *sg_id)
&& !sg.node_ids.contains(&node_id.to_string())
{
sg.node_ids.push(node_id.to_string());
}
}
/// Splits a `subgraph` statement into `(id, label)`.
///
/// Supported forms:
/// * `subgraph` — anonymous: `("__sg__", "")`;
/// * `subgraph Title` — the title is both id and label;
/// * `subgraph id [Label]` — explicit id and label (a missing `]` runs to the end);
/// * `subgraph [Label]` — the label doubles as the id.
fn parse_subgraph_header(stmt: &str) -> (String, String) {
    let rest = stmt.trim_start_matches("subgraph").trim();
    if rest.is_empty() {
        return (String::from("__sg__"), String::new());
    }
    let Some((id_part, after_open)) = rest.split_once('[') else {
        return (rest.to_owned(), rest.to_owned());
    };
    // The label stops at the first `]`; `split` always yields at least one piece.
    let label = after_open
        .split(']')
        .next()
        .unwrap_or(after_open)
        .trim()
        .to_owned();
    let id = match id_part.trim() {
        "" => label.clone(),
        named => named.to_owned(),
    };
    (id, label)
}
fn looks_like_edge_chain(s: &str) -> bool {
s.contains("-->")
|| s.contains("---")
|| s.contains("-.->")
|| s.contains("==>")
|| s.contains("<-->")
|| s.contains("--o")
|| s.contains("--x")
|| s.contains("-- ") || s.contains("--") }
/// Parses a chain such as `A[Start] -->|yes| B --- C` into nodes and edges.
///
/// Tokens alternate node, arrow, node, …. Every node is upserted and registered with the
/// enclosing subgraph. Each node after the first is connected to its predecessor, using the
/// style, endpoints and label of the arrow between them. A trailing arrow with no target
/// node creates no edge.
fn parse_edge_chain(stmt: &str, graph: &mut Graph, current_subgraph_id: &mut Option<String>) {
    // Attributes of the most recent arrow. They describe the *next* edge to be created and
    // are reset to a plain solid arrow once that edge exists.
    let mut label: Option<String> = None;
    let mut style = EdgeStyle::Solid;
    let mut start = EdgeEndpoint::None;
    let mut end = EdgeEndpoint::Arrow;
    // Id of the node the next edge starts from.
    let mut previous: Option<String> = None;

    for (position, raw) in tokenise_chain(stmt).iter().enumerate() {
        let tok = raw.trim();
        if position % 2 == 1 {
            // Odd positions hold arrows.
            (style, start, end) = classify_arrow(tok);
            label = extract_arrow_label(tok);
            continue;
        }
        // Even positions hold node text.
        if tok.is_empty() {
            continue;
        }
        let node = parse_node_definition(tok)
            .unwrap_or_else(|| Node::new(tok, tok, NodeShape::Rectangle));
        let node_id = node.id.clone();
        graph.upsert_node(node);
        register_node_in_subgraph(graph, &node_id, current_subgraph_id);
        if let Some(from) = previous.take() {
            graph.edges.push(Edge::new_styled(
                from,
                node_id.clone(),
                label.take(),
                std::mem::replace(&mut style, EdgeStyle::Solid),
                std::mem::replace(&mut start, EdgeEndpoint::None),
                std::mem::replace(&mut end, EdgeEndpoint::Arrow),
            ));
        }
        previous = Some(node_id);
    }
}
/// Splits an edge-chain statement into alternating node and arrow tokens.
///
/// Even positions hold trimmed node text and odd positions hold arrow text from
/// `consume_arrow`. An arrow is only recognised once some non-blank node text has been
/// collected, so a leading `-`, `=` or `<` stays part of the node text.
fn tokenise_chain(stmt: &str) -> Vec<String> {
    let chars: Vec<char> = stmt.chars().collect();
    let mut tokens = Vec::new();
    let mut node_text = String::new();
    let mut pos = 0;
    while let Some(&ch) = chars.get(pos) {
        let may_open_arrow = matches!(ch, '-' | '=' | '<') && !node_text.trim().is_empty();
        if may_open_arrow && is_arrow_start(&chars, pos) {
            tokens.push(std::mem::take(&mut node_text).trim().to_string());
            let (arrow, width) = consume_arrow(&chars, pos);
            tokens.push(arrow);
            pos += width;
        } else {
            node_text.push(ch);
            pos += 1;
        }
    }
    let tail = node_text.trim();
    if !tail.is_empty() {
        tokens.push(tail.to_string());
    }
    tokens
}
/// Returns `true` if an arrow token begins at `chars[i]`.
///
/// Every solid form (`-->`, `---`, `--o`, `--x`, `-- label -->`, longer links) starts with
/// `--`, so that opener covers them all. Dotted (`-.-`, `-.->`), thick (`===`, `==>`) and
/// bidirectional (`<-->`) arrows need their own openers. An out-of-range `i` yields `false`.
/// Characters are compared in place, without building a `String` per probed position.
fn is_arrow_start(chars: &[char], i: usize) -> bool {
    const OPENERS: [&str; 5] = ["--", "-.-", "===", "==>", "<-->"];
    let Some(rest) = chars.get(i..) else {
        return false;
    };
    // Openers are ASCII, so their byte length equals their char length.
    OPENERS
        .iter()
        .any(|opener| rest.iter().copied().take(opener.len()).eq(opener.chars()))
}
/// Maps an arrow token (as produced by `consume_arrow`) to `(style, start, end)`.
///
/// Any `|label|` suffix is ignored. Rules, checked in order:
/// * `<…>` — solid, arrowheads at both ends;
/// * `-…o` / `-…x` — solid, circle / cross at the target;
/// * `=…` — thick, arrowhead at the target only if the token ends in `>`;
/// * containing `.-` or `-.` — dotted, arrowhead only if the token ends in `>`;
/// * `-…` without a `>` tip — solid line with no arrowheads;
/// * anything else — solid, arrowhead at the target.
fn classify_arrow(arrow: &str) -> (EdgeStyle, EdgeEndpoint, EdgeEndpoint) {
    // `split` always yields at least one piece: the text before the first `|`.
    let base = arrow.split('|').next().unwrap_or(arrow).trim();
    let opens = |c: char| base.starts_with(c);
    let closes = |c: char| base.ends_with(c);
    let tip = if closes('>') {
        EdgeEndpoint::Arrow
    } else {
        EdgeEndpoint::None
    };

    if opens('<') && closes('>') {
        (EdgeStyle::Solid, EdgeEndpoint::Arrow, EdgeEndpoint::Arrow)
    } else if opens('-') && closes('o') {
        (EdgeStyle::Solid, EdgeEndpoint::None, EdgeEndpoint::Circle)
    } else if opens('-') && closes('x') {
        (EdgeStyle::Solid, EdgeEndpoint::None, EdgeEndpoint::Cross)
    } else if opens('=') {
        (EdgeStyle::Thick, EdgeEndpoint::None, tip)
    } else if base.contains(".-") || base.contains("-.") {
        (EdgeStyle::Dotted, EdgeEndpoint::None, tip)
    } else if opens('-') && !closes('>') {
        // `o`/`x` endings were handled above, so this is a plain line.
        (EdgeStyle::Solid, EdgeEndpoint::None, EdgeEndpoint::None)
    } else {
        (EdgeStyle::Solid, EdgeEndpoint::None, EdgeEndpoint::Arrow)
    }
}
/// Consumes one arrow token from `chars`, starting at index `start`.
///
/// `tokenise_chain` calls this only after `is_arrow_start` has reported an arrow at `start`.
/// Recognised forms; all except the inline-labelled arrow may be followed by a `|label|`:
/// * `<-->` — bidirectional;
/// * `-- label -->` — inline-labelled arrow;
/// * `-.-` / `-.->` — dotted line / arrow;
/// * `==…` — thick link of any length, with an optional `>` tip;
/// * `--…` — solid link of two or more dashes, with an optional `>`, `o` or `x` tip
///   (`-->`, `---`, `--->`, `----`, `--o`, `--x`, …).
///
/// Returns the arrow text (including any `|label|` suffix) and the number of **chars** it
/// spans. The count is derived from the returned text, so multi-byte characters in a label
/// cannot desynchronise the char-indexed tokenizer.
fn consume_arrow(chars: &[char], start: usize) -> (String, usize) {
    let remaining: String = chars[start..].iter().collect();
    let tok = remaining[..arrow_byte_len(&remaining)].to_string();
    let consumed = tok.chars().count();
    (tok, consumed)
}

/// Byte length of the arrow token at the start of `s`; the token is always a prefix of `s`.
fn arrow_byte_len(s: &str) -> usize {
    // Byte length of a `|label|` immediately following an arrow, or 0 if there is none.
    let pipe_label = |rest: &str| try_consume_pipe_label(rest).0.len();

    if let Some(rest) = s.strip_prefix("<-->") {
        return 4 + pipe_label(rest);
    }
    if let Some(arrow) = try_consume_labeled_dash_arrow(s) {
        return arrow.len();
    }
    if s.starts_with("-.-") {
        let base = if s.starts_with("-.->") { 4 } else { 3 };
        return base + pipe_label(&s[base..]);
    }
    if s.starts_with("==") {
        let mut len = s.bytes().take_while(|&b| b == b'=').count();
        if s[len..].starts_with('>') {
            len += 1;
        }
        return len + pipe_label(&s[len..]);
    }
    if s.starts_with("--") {
        // Mermaid links are two or more dashes followed by an optional `>`, `o` or `x` tip.
        let mut len = s.bytes().take_while(|&b| b == b'-').count();
        if matches!(s.as_bytes().get(len).copied(), Some(b'>' | b'o' | b'x')) {
            len += 1;
        }
        return len + pipe_label(&s[len..]);
    }
    // Defensive fallback: consume at least one char so the tokenizer always advances.
    s.chars().next().map_or(0, char::len_utf8)
}

/// Matches an inline-labelled arrow `-- label -->` at the start of `s`.
///
/// Returns the whole arrow text (from the leading `-- ` through the first following `-->`),
/// or `None` if `s` does not start with `-- ` or has no closing `-->`.
fn try_consume_labeled_dash_arrow(s: &str) -> Option<String> {
    let after_open = s.strip_prefix("-- ")?;
    let label_len = after_open.find("-->")?;
    // "-- " + label text + "-->"
    Some(s[..3 + label_len + 3].to_string())
}

/// Matches a `|label|` at the start of `s`.
///
/// Returns the matched text including both pipes, together with its length in **chars**
/// (the unit the tokenizer advances by). Returns `(String::new(), 0)` when `s` does not
/// start with a complete `|…|` pair.
fn try_consume_pipe_label(s: &str) -> (String, usize) {
    let Some(end) = s.strip_prefix('|').and_then(|inner| inner.find('|')) else {
        return (String::new(), 0);
    };
    // `end` is a byte offset into `s[1..]`, so `end + 2` covers both pipes.
    let portion = &s[..end + 2];
    (portion.to_string(), portion.chars().count())
}
/// Extracts the edge label carried by an arrow token, if any.
///
/// Two forms are recognised, in this order:
/// * `-->|text|` — the trimmed text between the first two pipes;
/// * `-- text -->` — the trimmed text between the leading `-- ` and the last `-->`.
///
/// Blank labels are reported as `None`.
fn extract_arrow_label(arrow: &str) -> Option<String> {
    let piped = arrow
        .split_once('|')
        .and_then(|(_, after_first)| after_first.split_once('|'))
        .map(|(between, _)| between.trim())
        .filter(|text| !text.is_empty());
    if let Some(text) = piped {
        return Some(text.to_string());
    }
    if !arrow.starts_with("-- ") {
        return None;
    }
    let close = arrow.rfind("-->")?;
    // A leading `-- ` rules out a `-->` match before index 3, so this range is ordered.
    let text = arrow[3..close].trim();
    if text.is_empty() {
        None
    } else {
        Some(text.to_string())
    }
}
pub(crate) fn parse_node_definition(token: &str) -> Option<Node> {
let token = token.trim();
if token.is_empty() {
return None;
}
if token.starts_with('>') && token.ends_with(']') {
if let Some(pos) = token.find('>') {
let id = token[..pos].trim().to_string();
if !id.is_empty() {
let inner = token[pos + 1..token.len() - 1].trim().to_string();
let label = strip_html_breaks(inner);
return Some(Node::new(id, label, NodeShape::Asymmetric));
}
}
}
let shape_start = token.find(['[', '{', '(', '>']);
let (id, label, shape) = if let Some(pos) = shape_start {
let id = token[..pos].trim().to_string();
let rest = &token[pos..];
if rest.starts_with("(((") && rest.ends_with(")))") {
let inner = rest[3..rest.len() - 3].trim().to_string();
(id, inner, NodeShape::DoubleCircle)
}
else if rest.starts_with("([") && rest.ends_with("])") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Stadium)
}
else if rest.starts_with("[(") && rest.ends_with(")]") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Cylinder)
}
else if rest.starts_with("[[") && rest.ends_with("]]") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Subroutine)
}
else if rest.starts_with("[/") && rest.ends_with("/]") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Parallelogram)
}
else if rest.starts_with("[/") && rest.ends_with("\\]") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Trapezoid)
}
else if rest.starts_with("{{") && rest.ends_with("}}") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Hexagon)
}
else if rest.starts_with("((") && rest.ends_with("))") {
let inner = rest[2..rest.len() - 2].trim().to_string();
(id, inner, NodeShape::Circle)
}
else if rest.starts_with('{') && rest.ends_with('}') {
let inner = rest[1..rest.len() - 1].trim().to_string();
(id, inner, NodeShape::Diamond)
}
else if rest.starts_with('[') && rest.ends_with(']') {
let inner = rest[1..rest.len() - 1].trim().to_string();
(id, inner, NodeShape::Rectangle)
}
else if rest.starts_with('(') && rest.ends_with(')') {
let inner = rest[1..rest.len() - 1].trim().to_string();
(id, inner, NodeShape::Rounded)
}
else if rest.starts_with('>') && rest.ends_with(']') {
let inner = rest[1..rest.len() - 1].trim().to_string();
(id, inner, NodeShape::Asymmetric)
}
else {
let id = token.to_string();
(id.clone(), id, NodeShape::Rectangle)
}
} else {
(token.to_string(), token.to_string(), NodeShape::Rectangle)
};
if id.is_empty() {
return None;
}
let label = strip_html_breaks(label);
Some(Node::new(id, label, shape))
}
/// Replaces the HTML line-break tags `<br/>`, `<br>` and `<br />` in a label with a space.
fn strip_html_breaks(s: String) -> String {
    ["<br/>", "<br>", "<br />"]
        .into_iter()
        .fold(s, |label, tag| label.replace(tag, " "))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EdgeEndpoint, EdgeStyle, NodeShape};

    /// Parses a complete diagram, failing the test if parsing returns an error.
    fn parsed(src: &str) -> Graph {
        parse(src).expect("diagram should parse")
    }

    /// Parses `body` as the statement list of a `graph LR` diagram.
    fn lr(body: &str) -> Graph {
        parsed(&format!("graph LR\n{body}"))
    }

    /// Asserts that node `id` exists in `g` with the given shape.
    fn assert_shape(g: &Graph, id: &str, shape: NodeShape) {
        assert_eq!(g.node(id).expect("node should exist").shape, shape);
    }

    /// Asserts that node `id` exists in `g` with the given label.
    fn assert_label(g: &Graph, id: &str, label: &str) {
        assert_eq!(g.node(id).expect("node should exist").label, label);
    }

    #[test]
    fn parse_simple_lr() {
        let g = lr("A-->B-->C");
        assert_eq!(g.direction, Direction::LeftToRight);
        for id in ["A", "B", "C"] {
            assert!(g.has_node(id), "missing {id}");
        }
        assert_eq!(g.edges.len(), 2);
    }

    #[test]
    fn parse_semicolons() {
        let g = parsed("graph LR; A-->B; B-->C");
        assert_eq!(g.edges.len(), 2);
    }

    #[test]
    fn parse_labeled_nodes() {
        let g = lr("A[Start] --> B[End]");
        assert_label(&g, "A", "Start");
        assert_label(&g, "B", "End");
    }

    #[test]
    fn parse_diamond_node() {
        let g = lr("A{Decision}");
        assert_shape(&g, "A", NodeShape::Diamond);
        assert_label(&g, "A", "Decision");
    }

    #[test]
    fn parse_circle_node() {
        assert_shape(&lr("A((Circle))"), "A", NodeShape::Circle);
    }

    #[test]
    fn parse_rounded_node() {
        assert_shape(&lr("A(Rounded)"), "A", NodeShape::Rounded);
    }

    #[test]
    fn parse_edge_label_pipe() {
        let g = lr("A -->|yes| B");
        assert_eq!(g.edges[0].label.as_deref(), Some("yes"));
    }

    #[test]
    fn parse_edge_label_dash() {
        let g = lr("A -- hello --> B");
        assert_eq!(g.edges[0].label.as_deref(), Some("hello"));
    }

    #[test]
    fn parse_flowchart_keyword() {
        let g = parsed("flowchart TD\nA-->B");
        assert_eq!(g.direction, Direction::TopToBottom);
    }

    #[test]
    fn bad_direction_returns_error() {
        assert!(parse("graph XY\nA-->B").is_err());
    }

    #[test]
    fn no_header_returns_error() {
        assert!(parse("A-->B").is_err());
    }

    #[test]
    fn parse_subgraph_basic() {
        let g = lr("subgraph Supervisor\nF[Factory] --> W[Worker]\nend");
        assert!(g.has_node("F"), "missing F");
        assert!(g.has_node("W"), "missing W");
        assert_eq!(g.subgraphs.len(), 1);
        let sg = &g.subgraphs[0];
        assert_eq!(sg.id, "Supervisor");
        assert_eq!(sg.label, "Supervisor");
        for id in ["F", "W"] {
            assert!(sg.node_ids.contains(&id.to_string()));
        }
    }

    #[test]
    fn parse_subgraph_with_direction() {
        let g = lr("subgraph S\ndirection TB\nA-->B\nend");
        assert_eq!(g.subgraphs[0].direction, Some(Direction::TopToBottom));
    }

    #[test]
    fn parse_nested_subgraphs() {
        let g = parsed("graph TD\nsubgraph Outer\nsubgraph Inner\nA[A]\nend\nB[B]\nend");
        let outer = g.find_subgraph("Outer").expect("Outer subgraph");
        let inner = g.find_subgraph("Inner").expect("Inner subgraph");
        assert!(outer.subgraph_ids.contains(&"Inner".to_string()));
        assert!(inner.node_ids.contains(&"A".to_string()));
        assert!(outer.node_ids.contains(&"B".to_string()));
    }

    #[test]
    fn parse_subgraph_edge_crossing_boundary() {
        let g = lr("subgraph S\nF[Factory] --> W[Worker]\nend\nW --> HB[Heartbeat]");
        for id in ["F", "W", "HB"] {
            assert!(g.has_node(id));
        }
        assert!(g.edges.iter().any(|e| e.from == "W" && e.to == "HB"));
        let s = g.find_subgraph("S").expect("S subgraph");
        assert!(!s.node_ids.contains(&"HB".to_string()));
    }

    #[test]
    fn node_to_subgraph_map() {
        let g = lr("subgraph S\nA-->B\nend\nC-->D");
        let map = g.node_to_subgraph();
        for id in ["A", "B"] {
            assert_eq!(map.get(id).map(String::as_str), Some("S"));
        }
        for id in ["C", "D"] {
            assert!(!map.contains_key(id));
        }
    }

    #[test]
    fn parse_stadium_node() {
        let g = lr("A([Stadium])");
        assert_shape(&g, "A", NodeShape::Stadium);
        assert_label(&g, "A", "Stadium");
    }

    #[test]
    fn parse_subroutine_node() {
        let g = lr("A[[Sub]]");
        assert_shape(&g, "A", NodeShape::Subroutine);
        assert_label(&g, "A", "Sub");
    }

    #[test]
    fn parse_cylinder_node() {
        let g = lr("A[(DB)]");
        assert_shape(&g, "A", NodeShape::Cylinder);
        assert_label(&g, "A", "DB");
    }

    #[test]
    fn parse_hexagon_node() {
        let g = lr("A{{Hex}}");
        assert_shape(&g, "A", NodeShape::Hexagon);
        assert_label(&g, "A", "Hex");
    }

    #[test]
    fn parse_asymmetric_node() {
        let g = lr("A>Flag]");
        assert_shape(&g, "A", NodeShape::Asymmetric);
        assert_label(&g, "A", "Flag");
    }

    #[test]
    fn parse_parallelogram_node() {
        let g = lr("A[/Lean/]");
        assert_shape(&g, "A", NodeShape::Parallelogram);
        assert_label(&g, "A", "Lean");
    }

    #[test]
    fn parse_trapezoid_node() {
        let g = lr("A[/Trap\\]");
        assert_shape(&g, "A", NodeShape::Trapezoid);
        assert_label(&g, "A", "Trap");
    }

    #[test]
    fn parse_double_circle_node() {
        let g = lr("A(((Dbl)))");
        assert_shape(&g, "A", NodeShape::DoubleCircle);
        assert_label(&g, "A", "Dbl");
    }

    #[test]
    fn triple_paren_beats_double_paren() {
        assert_shape(&lr("A(((X)))"), "A", NodeShape::DoubleCircle);
    }

    #[test]
    fn double_bracket_beats_single_bracket() {
        assert_shape(&lr("A[[Y]]"), "A", NodeShape::Subroutine);
    }

    #[test]
    fn parse_dotted_edge_style() {
        let g = lr("A-.->B");
        let edge = &g.edges[0];
        assert_eq!(edge.style, EdgeStyle::Dotted);
        assert_eq!(edge.end, EdgeEndpoint::Arrow);
    }

    #[test]
    fn parse_thick_edge_style() {
        let g = lr("A==>B");
        let edge = &g.edges[0];
        assert_eq!(edge.style, EdgeStyle::Thick);
        assert_eq!(edge.end, EdgeEndpoint::Arrow);
    }

    #[test]
    fn parse_plain_line_no_arrow() {
        let g = lr("A---B");
        let edge = &g.edges[0];
        assert_eq!(edge.style, EdgeStyle::Solid);
        assert_eq!(edge.end, EdgeEndpoint::None);
        assert_eq!(edge.start, EdgeEndpoint::None);
    }

    #[test]
    fn parse_bidirectional_edge() {
        let g = lr("A<-->B");
        let edge = &g.edges[0];
        assert_eq!(edge.style, EdgeStyle::Solid);
        assert_eq!(edge.start, EdgeEndpoint::Arrow);
        assert_eq!(edge.end, EdgeEndpoint::Arrow);
    }

    #[test]
    fn parse_circle_endpoint() {
        let g = lr("A--oB");
        assert_eq!(g.edges[0].end, EdgeEndpoint::Circle);
    }

    #[test]
    fn parse_cross_endpoint() {
        let g = lr("A--xB");
        assert_eq!(g.edges[0].end, EdgeEndpoint::Cross);
    }
}