use std::collections::HashSet;
use super::flows::{Flow, FlowError};
/// The role a node plays in a rendered flow diagram.
///
/// The kind selects the node's shape in `FlowGraphDiagram::mermaid` and
/// `FlowGraphDiagram::dot`, and the suffix word shown in its label (see
/// `label_suffix`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiagramNodeKind {
    /// Labelled with the `agent` suffix; drawn as a (rounded) box.
    Agent,
    /// Labelled with the `work` suffix; drawn as a (rounded) box.
    Work,
    /// Labelled with the `fork` suffix; drawn as a rhombus/diamond.
    Fork,
    /// Labelled with the `join` suffix; drawn as a stadium/ellipse.
    Join,
    /// Labelled with the `either` suffix; drawn as a rhombus/diamond.
    Either,
    /// An edge target that has no node description of its own;
    /// `build_diagram` synthesises these.
    Terminal,
}

impl DiagramNodeKind {
    /// Lower-case word that names this kind in rendered node labels,
    /// e.g. `"agent"` for `DiagramNodeKind::Agent`.
    fn label_suffix(&self) -> &'static str {
        match *self {
            DiagramNodeKind::Agent => "agent",
            DiagramNodeKind::Work => "work",
            DiagramNodeKind::Fork => "fork",
            DiagramNodeKind::Join => "join",
            DiagramNodeKind::Either => "either",
            DiagramNodeKind::Terminal => "terminal",
        }
    }
}
/// One node of a `FlowGraphDiagram`.
#[derive(Debug, Clone)]
pub struct DiagramNode {
    /// The node's id. It is shown verbatim (escaped) in the rendered label and
    /// is also sanitised/quoted to form the diagram-language identifier.
    pub id: String,
    /// Determines the node's shape and label suffix.
    pub kind: DiagramNodeKind,
}
/// A directed, labelled edge between two nodes of a `FlowGraphDiagram`.
#[derive(Debug, Clone)]
pub struct DiagramEdge {
    /// Id of the source node.
    pub from: String,
    /// Id of the target node.
    pub to: String,
    /// Text shown on the edge: inside `-->|…|` in Mermaid output and as the
    /// `label` attribute in DOT output.
    pub label: &'static str,
}
/// A renderable description of a flow graph: an entry node id plus the nodes
/// and edges to draw. Render it with `mermaid()` or `dot()`.
#[derive(Debug, Clone)]
pub struct FlowGraphDiagram {
    // Id of the node that the synthetic `_start` marker points at.
    entry: String,
    // Nodes in the order they are emitted by the renderers.
    nodes: Vec<DiagramNode>,
    // Edges in the order they are emitted by the renderers.
    edges: Vec<DiagramEdge>,
}
impl FlowGraphDiagram {
    /// Builds the diagram for the flow type `F`.
    ///
    /// Calls `F::build()`, marks `F::node_id()` as the entry via `with_entry`,
    /// and returns that graph's `diagram()`.
    ///
    /// # Errors
    ///
    /// Returns any [`FlowError`] produced by `F::build()` or `with_entry`.
    pub fn for_flow<F: Flow>() -> Result<Self, FlowError> {
        let graph = F::build()?.with_entry(F::node_id())?;
        Ok(graph.diagram())
    }

    /// Creates a diagram from its parts.
    ///
    /// No validation is performed: `entry` and the edge endpoints are not
    /// checked against `nodes`.
    pub(crate) fn new(entry: String, nodes: Vec<DiagramNode>, edges: Vec<DiagramEdge>) -> Self {
        Self {
            entry,
            nodes,
            edges,
        }
    }

    /// Id of the entry node that the synthetic `_start` marker points at.
    pub fn entry(&self) -> &str {
        &self.entry
    }

    /// All nodes, in the order they are rendered.
    pub fn nodes(&self) -> &[DiagramNode] {
        &self.nodes
    }

    /// All edges, in the order they are rendered.
    pub fn edges(&self) -> &[DiagramEdge] {
        &self.edges
    }

    /// Renders the diagram as Mermaid `flowchart LR` source.
    ///
    /// A synthetic `_start` circle points at the entry node. Shapes by kind:
    /// agent/work → rectangle, fork/either → rhombus, join → stadium,
    /// terminal → stadium labelled with `◉`. Node identifiers come from
    /// `mermaid_id`. The visible label text (node ids and edge labels) is passed
    /// through [`Self::escape_mermaid_label`]. Without that step, a `"` or `|` in
    /// an id or label would terminate the label early and break the output.
    pub fn mermaid(&self) -> String {
        let mut out = String::from("flowchart LR\n");
        out.push_str(" _start(( ))\n");
        for node in &self.nodes {
            let safe_id = mermaid_id(&node.id);
            let label = Self::escape_mermaid_label(&node.id);
            let decl = match node.kind {
                DiagramNodeKind::Agent | DiagramNodeKind::Work => {
                    format!(
                        " {}[\"{} ({})\"]",
                        safe_id,
                        label,
                        node.kind.label_suffix()
                    )
                }
                DiagramNodeKind::Fork | DiagramNodeKind::Either => {
                    format!(
                        " {}{{\"{} ({})\"}}",
                        safe_id,
                        label,
                        node.kind.label_suffix()
                    )
                }
                DiagramNodeKind::Join => {
                    format!(" {}([\"{} (join)\"])", safe_id, label)
                }
                DiagramNodeKind::Terminal => {
                    format!(" {}([\"{} ◉\"])", safe_id, label)
                }
            };
            out.push_str(&decl);
            out.push('\n');
        }
        out.push_str(&format!(" _start --> {}\n", mermaid_id(&self.entry)));
        for edge in &self.edges {
            out.push_str(&format!(
                " {} -->|{}| {}\n",
                mermaid_id(&edge.from),
                Self::escape_mermaid_label(edge.label),
                mermaid_id(&edge.to)
            ));
        }
        out
    }

    /// Renders the diagram as a Graphviz DOT `digraph` laid out left-to-right.
    ///
    /// A filled `_start` circle points at the entry node. Shapes by kind:
    /// agent/work → rounded box, fork/either → diamond, join → ellipse,
    /// terminal → double circle. Identifiers are quoted with `dot_id`. Label
    /// text is escaped with [`Self::escape_dot_label`] so a `"` or `\` in an id
    /// or edge label cannot end or corrupt the quoted `label` string.
    pub fn dot(&self) -> String {
        let mut out = String::from("digraph {\n rankdir=LR;\n");
        out.push_str(
            " _start [label=\"\" shape=circle style=filled fillcolor=black width=0.3];\n",
        );
        for node in &self.nodes {
            let safe_id = dot_id(&node.id);
            let label = Self::escape_dot_label(&node.id);
            let attrs = match node.kind {
                DiagramNodeKind::Agent | DiagramNodeKind::Work => format!(
                    "label=\"{}\\n({})\" shape=box style=rounded",
                    label,
                    node.kind.label_suffix()
                ),
                DiagramNodeKind::Fork | DiagramNodeKind::Either => format!(
                    "label=\"{}\\n({})\" shape=diamond",
                    label,
                    node.kind.label_suffix()
                ),
                DiagramNodeKind::Join => {
                    format!("label=\"{}\\n(join)\" shape=ellipse", label)
                }
                DiagramNodeKind::Terminal => {
                    format!("label=\"{}\" shape=doublecircle", label)
                }
            };
            out.push_str(&format!(" {} [{}];\n", safe_id, attrs));
        }
        out.push_str(&format!(" _start -> {};\n", dot_id(&self.entry)));
        for edge in &self.edges {
            out.push_str(&format!(
                " {} -> {} [label=\"{}\"];\n",
                dot_id(&edge.from),
                dot_id(&edge.to),
                Self::escape_dot_label(edge.label),
            ));
        }
        out.push_str("}\n");
        out
    }

    /// Renders `self.mermaid()` to text with `mermaid_text::render`.
    /// Only available with the `diagram-text` feature.
    #[cfg(feature = "diagram-text")]
    pub fn render_text(&self) -> Result<String, mermaid_text::Error> {
        mermaid_text::render(&self.mermaid())
    }

    /// Renders `self.mermaid()` to ASCII with `mermaid_text::render_ascii`.
    /// Only available with the `diagram-text` feature.
    #[cfg(feature = "diagram-text")]
    pub fn render_ascii(&self) -> Result<String, mermaid_text::Error> {
        mermaid_text::render_ascii(&self.mermaid())
    }

    /// Escapes text for a quoted Mermaid node label or an `-->|…|` edge label.
    ///
    /// Mermaid decodes `#name;` / `#NNN;` entity codes in label text. Each
    /// character is therefore mapped individually, which avoids double-escaping:
    /// - `#` becomes `#35;`, so a literal `#` is never read as the start of an entity.
    /// - `"` and `|` become entities, because either would otherwise end the label.
    ///
    /// All other characters pass through unchanged.
    fn escape_mermaid_label(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for c in text.chars() {
            match c {
                '#' => escaped.push_str("#35;"),
                '"' => escaped.push_str("#quot;"),
                '|' => escaped.push_str("#124;"),
                _ => escaped.push(c),
            }
        }
        escaped
    }

    /// Escapes text for use inside a double-quoted DOT `label` attribute.
    ///
    /// - Backslashes are doubled first, because DOT labels treat `\n`, `\l`,
    ///   `\N`, … as escape sequences.
    /// - `"` then becomes `\"` so the quoted string ends where intended.
    fn escape_dot_label(text: &str) -> String {
        text.replace('\\', "\\\\").replace('"', "\\\"")
    }
}
/// Turns an arbitrary node id into a Mermaid node identifier.
///
/// `_` and every character for which `char::is_alphanumeric` holds are kept.
/// That includes non-ASCII letters and digits. Every other character is
/// replaced by `_`.
///
/// The mapping is lossy. Ids that differ only in replaced characters (e.g.
/// `a-b` and `a.b`) produce the same Mermaid identifier.
fn mermaid_id(id: &str) -> String {
    let mut safe = String::with_capacity(id.len());
    for ch in id.chars() {
        let keep = ch == '_' || ch.is_alphanumeric();
        safe.push(if keep { ch } else { '_' });
    }
    safe
}
/// Renders `id` as a double-quoted DOT identifier.
///
/// Backslashes are doubled and `"` becomes `\"`, so the quoted string always
/// ends where intended. Without doubling, an id ending in `\` would turn the
/// closing quote into `\"` and leave the string unterminated.
///
/// In a quoted DOT string only `\"` is an escape; `\\` is kept as two
/// characters. Distinct input ids therefore still map to distinct DOT ids.
fn dot_id(id: &str) -> String {
    let mut quoted = String::with_capacity(id.len() + 2);
    quoted.push('"');
    for c in id.chars() {
        match c {
            '\\' => quoted.push_str("\\\\"),
            '"' => quoted.push_str("\\\""),
            _ => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}
/// Input to `build_diagram`: one described node and its outgoing edges.
pub(crate) struct NodeDesc {
    /// The node's id.
    pub id: String,
    /// The node's kind (shape/suffix in the rendered diagram).
    pub kind: DiagramNodeKind,
    /// Outgoing edges as `(target id, edge label)` pairs. Targets that are not
    /// the `id` of any `NodeDesc` passed to `build_diagram` become `Terminal`
    /// nodes.
    pub succs: Vec<(String, &'static str)>,
}
/// Assembles a `FlowGraphDiagram` from per-node descriptions.
///
/// Each description becomes a node of its declared kind, in input order. Each
/// `(to, label)` in its `succs` becomes an edge `desc.id -> to` carrying
/// `label`.
///
/// Edge targets that match no description's `id` are appended as
/// `DiagramNodeKind::Terminal` nodes, once each, after the described nodes.
/// They appear in the order their first referencing edge was seen. The result
/// (and so the rendered Mermaid/DOT text) is deterministic for a given input.
pub(crate) fn build_diagram(entry: String, descs: Vec<NodeDesc>) -> FlowGraphDiagram {
    let defined_ids: HashSet<&str> = descs.iter().map(|d| d.id.as_str()).collect();

    let mut nodes: Vec<DiagramNode> = Vec::with_capacity(descs.len());
    let mut edges: Vec<DiagramEdge> = Vec::new();
    // Terminal targets in first-seen order. The old code iterated a `HashSet`,
    // whose order is unspecified and randomised per process, so the emitted
    // diagram text could change between runs. The set below only dedupes.
    let mut terminal_ids: Vec<String> = Vec::new();
    let mut seen_terminals: HashSet<&str> = HashSet::new();

    for desc in &descs {
        nodes.push(DiagramNode {
            id: desc.id.clone(),
            kind: desc.kind.clone(),
        });
        for (to, label) in &desc.succs {
            edges.push(DiagramEdge {
                from: desc.id.clone(),
                to: to.clone(),
                label: *label,
            });
            if !defined_ids.contains(to.as_str()) && seen_terminals.insert(to.as_str()) {
                terminal_ids.push(to.clone());
            }
        }
    }

    nodes.extend(terminal_ids.into_iter().map(|id| DiagramNode {
        id,
        kind: DiagramNodeKind::Terminal,
    }));
    FlowGraphDiagram::new(entry, nodes, edges)
}