use std::collections::HashMap;
use portgraph::render::{EdgeStyle, NodeStyle, PortStyle, PresentationStyle};
use portgraph::{LinkView, MultiPortGraph, NodeIndex, PortIndex, PortView};
use crate::core::HugrNode;
use crate::hugr::internal::HugrInternals;
use crate::ops::{NamedOp, OpType};
use crate::types::EdgeKind;
use crate::{Hugr, HugrView, Node};
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MermaidFormatter<'h, H: HugrInternals + ?Sized = Hugr> {
hugr: &'h H,
node_labels: NodeLabel<H::Node>,
port_offsets_in_edges: bool,
type_labels_in_edges: bool,
entrypoint: Option<H::Node>,
}
impl<'h, H: HugrInternals + ?Sized> MermaidFormatter<'h, H> {
pub fn new(hugr: &'h H) -> Self {
Self {
hugr,
node_labels: NodeLabel::Numeric,
port_offsets_in_edges: true,
type_labels_in_edges: true,
entrypoint: None,
}
}
pub fn entrypoint(&self) -> Option<H::Node> {
self.entrypoint
}
pub fn node_labels(&self) -> &NodeLabel<H::Node> {
&self.node_labels
}
pub fn port_offsets(&self) -> bool {
self.port_offsets_in_edges
}
pub fn type_labels(&self) -> bool {
self.type_labels_in_edges
}
pub fn with_node_labels(mut self, node_labels: NodeLabel<H::Node>) -> Self {
self.node_labels = node_labels;
self
}
pub fn with_port_offsets(mut self, show: bool) -> Self {
self.port_offsets_in_edges = show;
self
}
pub fn with_type_labels(mut self, show: bool) -> Self {
self.type_labels_in_edges = show;
self
}
pub fn with_entrypoint(mut self, entrypoint: impl Into<Option<H::Node>>) -> Self {
self.entrypoint = entrypoint.into();
self
}
pub fn finish(self) -> String
where
H: HugrView,
{
self.hugr.mermaid_string_with_formatter(self)
}
pub(crate) fn with_hugr<NewH: HugrInternals<Node = H::Node>>(
self,
hugr: &NewH,
) -> MermaidFormatter<'_, NewH> {
let MermaidFormatter {
hugr: _,
node_labels,
port_offsets_in_edges,
type_labels_in_edges,
entrypoint,
} = self;
MermaidFormatter {
hugr,
node_labels,
port_offsets_in_edges,
type_labels_in_edges,
entrypoint,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum UnsupportedRenderConfig {
#[error("Custom node labels are not supported in the `RenderConfig` struct")]
CustomNodeLabels,
}
macro_rules! impl_mermaid_formatter_from {
($t:ty, $($lifetime:tt)?) => {
impl<'h, $($lifetime,)? H: HugrView> From<MermaidFormatter<'h, $t>> for MermaidFormatter<'h, H> {
fn from(value: MermaidFormatter<'h, $t>) -> Self {
let MermaidFormatter {
hugr,
node_labels,
port_offsets_in_edges,
type_labels_in_edges,
entrypoint,
} = value;
MermaidFormatter {
hugr,
node_labels,
port_offsets_in_edges,
type_labels_in_edges,
entrypoint,
}
}
}
};
}
impl_mermaid_formatter_from!(&'hh H, 'hh);
impl_mermaid_formatter_from!(&'hh mut H, 'hh);
impl_mermaid_formatter_from!(std::rc::Rc<H>,);
impl_mermaid_formatter_from!(std::sync::Arc<H>,);
impl_mermaid_formatter_from!(Box<H>,);
impl<'h, H: HugrView + ToOwned> From<MermaidFormatter<'h, std::borrow::Cow<'_, H>>>
for MermaidFormatter<'h, H>
{
fn from(value: MermaidFormatter<'h, std::borrow::Cow<'_, H>>) -> Self {
let MermaidFormatter {
hugr,
node_labels,
port_offsets_in_edges,
type_labels_in_edges,
entrypoint,
} = value;
MermaidFormatter {
hugr,
node_labels,
port_offsets_in_edges,
type_labels_in_edges,
entrypoint,
}
}
}
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub enum NodeLabel<N: HugrNode = Node> {
None,
#[default]
Numeric,
Custom(HashMap<N, String>),
}
pub(in crate::hugr) fn node_style<'a>(
h: &'a Hugr,
formatter: MermaidFormatter<'a>,
) -> Box<dyn FnMut(NodeIndex) -> NodeStyle + 'a> {
fn node_name(h: &Hugr, n: NodeIndex) -> String {
match h.get_optype(n.into()) {
OpType::FuncDecl(f) => format!("FuncDecl: \"{}\"", f.func_name()),
OpType::FuncDefn(f) => format!("FuncDefn: \"{}\"", f.func_name()),
op => op.name().to_string(),
}
}
let mut entrypoint_style = PresentationStyle::default();
entrypoint_style.stroke = Some("#832561".to_string());
entrypoint_style.stroke_width = Some("3px".to_string());
let entrypoint = formatter.entrypoint.map(Node::into_portgraph);
match formatter.node_labels {
NodeLabel::Numeric => Box::new(move |n| {
if Some(n) == entrypoint {
NodeStyle::boxed(format!(
"({ni}) [**{name}**]",
ni = n.index(),
name = node_name(h, n)
))
.with_attrs(entrypoint_style.clone())
} else {
NodeStyle::boxed(format!(
"({ni}) {name}",
ni = n.index(),
name = node_name(h, n)
))
}
}),
NodeLabel::None => Box::new(move |n| {
if Some(n) == entrypoint {
NodeStyle::boxed(format!("[**{name}**]", name = node_name(h, n)))
.with_attrs(entrypoint_style.clone())
} else {
NodeStyle::boxed(node_name(h, n))
}
}),
NodeLabel::Custom(labels) => Box::new(move |n| {
if Some(n) == entrypoint {
NodeStyle::boxed(format!(
"({label}) [**{name}**]",
label = labels.get(&n.into()).unwrap_or(&n.index().to_string()),
name = node_name(h, n)
))
.with_attrs(entrypoint_style.clone())
} else {
NodeStyle::boxed(format!(
"({label}) {name}",
label = labels.get(&n.into()).unwrap_or(&n.index().to_string()),
name = node_name(h, n)
))
}
}),
}
}
pub(in crate::hugr) fn port_style(h: &Hugr) -> Box<dyn FnMut(PortIndex) -> PortStyle + '_> {
let graph = &h.graph;
Box::new(move |port| {
let node = graph.port_node(port).unwrap();
let optype = h.get_optype(node.into());
let offset = graph.port_offset(port).unwrap();
match optype.port_kind(offset).unwrap() {
EdgeKind::Function(pf) => PortStyle::new(html_escape::encode_text(&format!("{pf}"))),
EdgeKind::Const(ty) | EdgeKind::Value(ty) => {
PortStyle::new(html_escape::encode_text(&format!("{ty}")))
}
EdgeKind::StateOrder => {
if graph.port_links(port).count() > 0 {
PortStyle::text("", false)
} else {
PortStyle::Hidden
}
}
_ => PortStyle::text("", true),
}
})
}
#[allow(clippy::type_complexity)]
pub(in crate::hugr) fn edge_style<'a>(
h: &'a Hugr,
config: MermaidFormatter<'_>,
) -> Box<
dyn FnMut(
<MultiPortGraph<u32, u32, u32> as LinkView>::LinkEndpoint,
<MultiPortGraph<u32, u32, u32> as LinkView>::LinkEndpoint,
) -> EdgeStyle
+ 'a,
> {
let graph = &h.graph;
Box::new(move |src, tgt| {
let src_node = graph.port_node(src).unwrap();
let src_optype = h.get_optype(src_node.into());
let src_offset = graph.port_offset(src).unwrap();
let tgt_offset = graph.port_offset(tgt).unwrap();
let port_kind = src_optype.port_kind(src_offset).unwrap();
let style = match port_kind {
EdgeKind::StateOrder => EdgeStyle::Dotted,
EdgeKind::ControlFlow => EdgeStyle::Dashed,
EdgeKind::Const(_) | EdgeKind::Function(_) | EdgeKind::Value(_) => EdgeStyle::Solid,
};
fn type_label(e: EdgeKind) -> Option<String> {
match e {
EdgeKind::Const(ty) | EdgeKind::Value(ty) => Some(format!("{ty}")),
EdgeKind::Function(pf) => Some(format!("{pf}")),
_ => None,
}
}
let label = match (
config.port_offsets_in_edges,
type_label(port_kind).filter(|_| config.type_labels_in_edges),
) {
(true, Some(ty)) => {
format!("{}:{}\n{ty}", src_offset.index(), tgt_offset.index())
}
(true, _) => format!("{}:{}", src_offset.index(), tgt_offset.index()),
(false, Some(ty)) => ty.to_string(),
_ => return style,
};
style.with_label(label)
})
}
#[cfg(test)]
mod tests {
use crate::{NodeIndex, builder::test::simple_dfg_hugr};
use super::*;
#[cfg_attr(miri, ignore)] #[test]
fn test_custom_node_labels() {
let h = simple_dfg_hugr();
let node_labels = h
.nodes()
.map(|n| (n, format!("node_{}", n.index())))
.collect();
let config = h
.mermaid_format()
.with_node_labels(NodeLabel::Custom(node_labels));
insta::assert_snapshot!(h.mermaid_string_with_formatter(config));
}
}