use std::collections::{HashMap, HashSet};
use super::flows::{Flow, FlowError};
/// Category of a node in a [`FlowGraphDiagram`].
///
/// The kind decides the shape a node gets in the Mermaid and DOT renderings
/// and the tag printed after its id.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DiagramNodeKind {
    /// Drawn as a (rounded) box with an `agent` tag.
    Agent,
    /// Drawn as a (rounded) box with a `work` tag.
    Work,
    /// Drawn as a diamond with a `fork` tag.
    Fork,
    /// Drawn as a stadium / ellipse with a `join` tag.
    Join,
    /// Drawn as a diamond with an `either` tag.
    Either,
    /// End point marked with `◉`; `build_diagram` assigns this kind to
    /// successor ids that have no node description of their own.
    Terminal,
}
impl DiagramNodeKind {
fn label_suffix(&self) -> &'static str {
match self {
Self::Agent => "agent",
Self::Work => "work",
Self::Fork => "fork",
Self::Join => "join",
Self::Either => "either",
Self::Terminal => "terminal",
}
}
}
/// One node of a [`FlowGraphDiagram`].
#[derive(Clone, Debug)]
pub struct DiagramNode {
    /// Node id; it is also the text shown in rendered labels.
    pub id: String,
    /// Category of the node, which selects its rendered shape.
    pub kind: DiagramNodeKind,
}
/// A directed, labelled connection between two nodes of a [`FlowGraphDiagram`].
#[derive(Clone, Debug)]
pub struct DiagramEdge {
    /// Id of the node the edge leaves.
    pub from: String,
    /// Id of the node the edge enters.
    pub to: String,
    /// Text drawn on the edge; may be empty.
    pub label: &'static str,
}
/// Render-ready description of a flow graph: where it starts, which nodes it
/// has, and how they are connected.
#[derive(Clone, Debug)]
pub struct FlowGraphDiagram {
    /// Id of the node execution starts from.
    entry: String,
    /// Every node, in rendering order.
    nodes: Vec<DiagramNode>,
    /// Every edge, in rendering order.
    edges: Vec<DiagramEdge>,
}
impl FlowGraphDiagram {
    /// Builds the diagram for flow `F`.
    ///
    /// Calls `F::build()`, sets the entry to `F::node_id()` with `with_entry`,
    /// and returns the resulting graph's `diagram()`.
    ///
    /// # Errors
    ///
    /// Returns the [`FlowError`] produced by `F::build()` or by `with_entry`.
    pub fn for_flow<F: Flow>() -> Result<Self, FlowError> {
        let graph = F::build()?.with_entry(F::node_id())?;
        Ok(graph.diagram())
    }

    /// Creates a diagram from already-assembled parts; nothing is validated.
    pub(crate) fn new(entry: String, nodes: Vec<DiagramNode>, edges: Vec<DiagramEdge>) -> Self {
        Self {
            entry,
            nodes,
            edges,
        }
    }

    /// Id of the entry node.
    pub fn entry(&self) -> &str {
        &self.entry
    }

    /// All nodes, in rendering order.
    pub fn nodes(&self) -> &[DiagramNode] {
        &self.nodes
    }

    /// All edges, in rendering order.
    pub fn edges(&self) -> &[DiagramEdge] {
        &self.edges
    }

    /// Renders the diagram as Mermaid `flowchart LR` source.
    ///
    /// Node ids are turned into Mermaid identifiers with `mermaid_id`. The
    /// original id is kept in the quoted node label, with `"` written as the
    /// `#quot;` entity so it cannot terminate the label early. A `_start`
    /// marker points at the entry node. Edges with an empty label are drawn as
    /// plain arrows, because Mermaid requires text between `|…|`.
    pub fn mermaid(&self) -> String {
        // A raw `"` inside a quoted Mermaid label would end the label;
        // `#quot;` is Mermaid's entity code for the character.
        fn label_text(text: &str) -> String {
            text.replace('"', "#quot;")
        }

        let mut out = String::from("flowchart LR\n");
        out.push_str(" _start(( ))\n");
        for node in &self.nodes {
            let safe_id = mermaid_id(&node.id);
            let label = label_text(&node.id);
            let decl = match node.kind {
                DiagramNodeKind::Agent | DiagramNodeKind::Work => format!(
                    " {}[\"{} ({})\"]",
                    safe_id,
                    label,
                    node.kind.label_suffix()
                ),
                DiagramNodeKind::Fork | DiagramNodeKind::Either => format!(
                    " {}{{\"{} ({})\"}}",
                    safe_id,
                    label,
                    node.kind.label_suffix()
                ),
                DiagramNodeKind::Join => format!(" {}([\"{} (join)\"])", safe_id, label),
                DiagramNodeKind::Terminal => format!(" {}([\"{} ◉\"])", safe_id, label),
            };
            out.push_str(&decl);
            out.push('\n');
        }
        out.push_str(&format!(" _start --> {}\n", mermaid_id(&self.entry)));
        for edge in &self.edges {
            let from = mermaid_id(&edge.from);
            let to = mermaid_id(&edge.to);
            // `A -->|| B` is a Mermaid parse error, so unlabeled edges get a
            // bare arrow.
            if edge.label.is_empty() {
                out.push_str(&format!(" {} --> {}\n", from, to));
            } else {
                out.push_str(&format!(" {} -->|{}| {}\n", from, edge.label, to));
            }
        }
        out
    }

    /// Renders the diagram as a Graphviz DOT `digraph` (left-to-right).
    ///
    /// Node ids are quoted with `dot_id`. Node and edge label text is escaped
    /// so that `"` and `\` in ids or labels cannot break the DOT syntax or be
    /// read as DOT escape sequences.
    pub fn dot(&self) -> String {
        // In a quoted DOT label `"` ends the string and `\` starts an escape
        // sequence (`\n`, `\N`, ...), so both are escaped.
        fn label_text(text: &str) -> String {
            text.replace('\\', "\\\\").replace('"', "\\\"")
        }

        let mut out = String::from("digraph {\n rankdir=LR;\n");
        out.push_str(
            " _start [label=\"\" shape=circle style=filled fillcolor=black width=0.3];\n",
        );
        for node in &self.nodes {
            let safe_id = dot_id(&node.id);
            let label = label_text(&node.id);
            let attrs = match node.kind {
                DiagramNodeKind::Agent | DiagramNodeKind::Work => format!(
                    "label=\"{}\\n({})\" shape=box style=rounded",
                    label,
                    node.kind.label_suffix()
                ),
                DiagramNodeKind::Fork | DiagramNodeKind::Either => format!(
                    "label=\"{}\\n({})\" shape=diamond",
                    label,
                    node.kind.label_suffix()
                ),
                DiagramNodeKind::Join => {
                    format!("label=\"{}\\n(join)\" shape=ellipse", label)
                }
                DiagramNodeKind::Terminal => {
                    format!("label=\"{}\" shape=doublecircle", label)
                }
            };
            out.push_str(&format!(" {} [{}];\n", safe_id, attrs));
        }
        out.push_str(&format!(" _start -> {};\n", dot_id(&self.entry)));
        for edge in &self.edges {
            out.push_str(&format!(
                " {} -> {} [label=\"{}\"];\n",
                dot_id(&edge.from),
                dot_id(&edge.to),
                label_text(edge.label),
            ));
        }
        out.push_str("}\n");
        out
    }

    /// Passes [`Self::mermaid`] output to `mermaid_text::render`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `mermaid_text::render` reports.
    #[cfg(feature = "diagram-text")]
    pub fn render_text(&self) -> Result<String, mermaid_text::Error> {
        mermaid_text::render(&self.mermaid())
    }

    /// Passes [`Self::mermaid`] output to `mermaid_text::render_ascii`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `mermaid_text::render_ascii` reports.
    #[cfg(feature = "diagram-text")]
    pub fn render_ascii(&self) -> Result<String, mermaid_text::Error> {
        mermaid_text::render_ascii(&self.mermaid())
    }

    /// Renders the graph as an indented box-drawing tree rooted at the entry.
    ///
    /// Successors are listed sorted by target id (stable, so parallel edges to
    /// the same target keep their original order). A node reached a second
    /// time is printed with `↩` and not expanded again, so cycles terminate.
    pub fn render_tree(&self) -> String {
        // Every declared node gets an entry, even with no outgoing edges.
        let mut adj: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
        for node in &self.nodes {
            adj.entry(node.id.as_str()).or_default();
        }
        for edge in &self.edges {
            adj.entry(edge.from.as_str())
                .or_default()
                .push((edge.label, edge.to.as_str()));
        }
        for succs in adj.values_mut() {
            succs.sort_by_key(|(_, to)| *to);
        }
        let node_kind: HashMap<&str, &DiagramNodeKind> =
            self.nodes.iter().map(|n| (n.id.as_str(), &n.kind)).collect();
        let mut visited: HashSet<String> = HashSet::new();
        let mut out = String::new();
        tree_write_node(
            &self.entry,
            "",
            true,
            true,
            None,
            &mut visited,
            &adj,
            &node_kind,
            &mut out,
        );
        out
    }
}
/// Converts a node id into an identifier Mermaid accepts.
///
/// Every character that is not alphanumeric or `_` becomes `_`. This mapping
/// is lossy: distinct ids such as `a-b` and `a_b` map to the same identifier.
///
/// A lowercase `end` node id breaks a Mermaid flowchart because it is the
/// block-closing keyword. An empty identifier cannot be declared at all. Both
/// cases get a trailing `_`.
fn mermaid_id(id: &str) -> String {
    let mut safe: String = id
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    if safe.is_empty() || safe == "end" {
        safe.push('_');
    }
    safe
}
/// Quotes a node id as a DOT identifier.
///
/// Backslashes are doubled and `"` is written as `\"`. Without the doubling,
/// an id ending in `\` would turn the closing quote into an escaped quote and
/// leave the string unterminated. Node declarations and edges both go through
/// this function, so the resulting identifier is used consistently.
fn dot_id(id: &str) -> String {
    let mut quoted = String::with_capacity(id.len() + 2);
    quoted.push('"');
    for c in id.chars() {
        match c {
            '\\' => quoted.push_str("\\\\"),
            '"' => quoted.push_str("\\\""),
            other => quoted.push(other),
        }
    }
    quoted.push('"');
    quoted
}
/// Appends the line for `id` to `out` and, on the node's first visit, recurses
/// into its successors to draw a box-drawing tree.
///
/// * `prefix` – indentation inherited from the ancestors (unused for the root).
/// * `is_root` – the root is written as `● <node>` with no connector.
/// * `is_last` – last child of its parent: drawn with `└── ` and no `│`
///   continuation under it for its own children.
/// * `edge_label` – label of the edge leading here, shown as `[label] ` when
///   non-empty.
/// * `visited` – ids already expanded. A node met again is printed with a
///   trailing `↩` and not expanded, which also stops cycles from recursing.
/// * `adj` – `(edge label, target id)` successors keyed by source id, drawn in
///   the order given.
/// * `node_kind` – kind of each node; ids missing from it get no kind tag.
/// * `out` – output buffer the lines are appended to.
// The recursion threads its read-only context and output buffers explicitly
// instead of bundling them into a struct, hence the argument count.
#[allow(clippy::too_many_arguments)]
fn tree_write_node(
    id: &str,
    prefix: &str,
    is_root: bool,
    is_last: bool,
    edge_label: Option<&str>,
    visited: &mut HashSet<String>,
    adj: &HashMap<&str, Vec<(&str, &str)>>,
    node_kind: &HashMap<&str, &DiagramNodeKind>,
    out: &mut String,
) {
    let repeated = visited.contains(id);
    // Same tag vocabulary as the Mermaid/DOT renderers; terminals get a marker.
    let kind_tag = match node_kind.get(id).copied() {
        Some(DiagramNodeKind::Terminal) => " ◉".to_string(),
        Some(kind) => format!(" ({})", kind.label_suffix()),
        None => String::new(),
    };
    let revisit_marker = if repeated { " ↩" } else { "" };
    let display = format!("{}{}{}", id, kind_tag, revisit_marker);
    if is_root {
        out.push_str(&format!("● {}\n", display));
    } else {
        let connector = if is_last { "└── " } else { "├── " };
        let edge_part = match edge_label {
            Some(l) if !l.is_empty() => format!("[{}] ", l),
            _ => String::new(),
        };
        out.push_str(&format!("{}{}{}{}\n", prefix, connector, edge_part, display));
    }
    if repeated {
        return;
    }
    visited.insert(id.to_string());
    let succs = match adj.get(id) {
        Some(v) if !v.is_empty() => v,
        _ => return,
    };
    // Both continuations are as wide as the 4-column connectors, so subtrees
    // under a last child line up with subtrees under its siblings.
    let child_prefix = if is_root {
        " ".to_string()
    } else if is_last {
        format!("{}    ", prefix)
    } else {
        format!("{}│   ", prefix)
    };
    let last_index = succs.len() - 1;
    for (i, &(label, to)) in succs.iter().enumerate() {
        tree_write_node(
            to,
            &child_prefix,
            false,
            i == last_index,
            Some(label),
            visited,
            adj,
            node_kind,
            out,
        );
    }
}
/// Input record for `build_diagram`: one described node plus its outgoing edges.
pub(crate) struct NodeDesc {
    /// Id of the described node.
    pub id: String,
    /// Category of the described node.
    pub kind: DiagramNodeKind,
    /// Outgoing edges as `(target id, edge label)` pairs, in drawing order.
    pub succs: Vec<(String, &'static str)>,
}
/// Converts per-node descriptions into a [`FlowGraphDiagram`] rooted at `entry`.
///
/// Every `NodeDesc` becomes a node, in `descs` order. Every successor becomes an
/// edge, in `descs` order and then `succs` order. A successor id with no
/// matching `NodeDesc` is added once as a `Terminal` node. Terminals are appended
/// after the described nodes in the order they are first referenced, so the
/// node list, and the Mermaid/DOT text rendered from it, is the same on
/// every run.
pub(crate) fn build_diagram(entry: String, descs: Vec<NodeDesc>) -> FlowGraphDiagram {
    let defined_ids: HashSet<&str> = descs.iter().map(|d| d.id.as_str()).collect();
    let mut nodes: Vec<DiagramNode> = descs
        .iter()
        .map(|d| DiagramNode {
            id: d.id.clone(),
            kind: d.kind.clone(),
        })
        .collect();
    let mut edges: Vec<DiagramEdge> =
        Vec::with_capacity(descs.iter().map(|d| d.succs.len()).sum());
    // Dedup set only. Terminals are pushed in first-seen order; iterating a
    // HashSet instead would yield a per-process random order.
    let mut seen_terminals: HashSet<&str> = HashSet::new();
    for desc in &descs {
        for (to, label) in &desc.succs {
            edges.push(DiagramEdge {
                from: desc.id.clone(),
                to: to.clone(),
                label: *label,
            });
            if !defined_ids.contains(to.as_str()) && seen_terminals.insert(to.as_str()) {
                nodes.push(DiagramNode {
                    id: to.clone(),
                    kind: DiagramNodeKind::Terminal,
                });
            }
        }
    }
    FlowGraphDiagram::new(entry, nodes, edges)
}