use crate::config::{MermaidOptions, MermaidThemeMode};
use crate::pipeline::Transformer;
use crate::visit::{NodeAction, Visitor, walk_root};
use dmc_diagnostic::Code;
use dmc_diagnostic::metadata::SourceMeta;
use dmc_parser::ast::*;
use duck_diagnostic::{Diagnostic, Label, diag};
use std::collections::{BTreeMap, HashMap};
use std::io::Write;
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::sync::{Mutex, OnceLock};
/// Transformer that renders mermaid diagrams to SVG by invoking the external `mmdc` CLI.
pub struct Mermaid {
    /// Rendering options: themes, optional on-disk output directory, label/background
    /// settings, puppeteer config and SVG post-processing flags.
    opts: MermaidOptions,
    /// In-memory SVG cache keyed by the 64-bit hash computed in `render_cached`.
    cache: Mutex<HashMap<u64, String>>,
}
/// Process-wide memo of whether `mmdc --version` exited successfully; filled lazily on
/// the first call to `Mermaid::mmdc_available`.
static MMDC_AVAILABLE: OnceLock<bool> = OnceLock::new();
impl Default for Mermaid {
fn default() -> Self {
Self::from_options(MermaidOptions::default())
}
}
impl Mermaid {
    /// Builds a transformer from explicit options with an empty in-memory SVG cache.
    pub fn from_options(opts: MermaidOptions) -> Self {
        Self { opts, cache: Mutex::new(HashMap::new()) }
    }

    /// Builds a transformer with default options whose rendered SVGs are also persisted
    /// under `p` (one `<key>.svg` file per render, see `render_cached`).
    pub fn with_output(p: impl Into<PathBuf>) -> Self {
        Self::from_options(MermaidOptions { output_dir: Some(p.into()), ..Default::default() })
    }

    /// Returns whether `mmdc --version` runs and exits successfully.
    ///
    /// The probe is executed at most once per process; the answer is memoised in
    /// `MMDC_AVAILABLE`.
    fn mmdc_available() -> bool {
        *MMDC_AVAILABLE.get_or_init(|| {
            Command::new("mmdc")
                .arg("--version")
                .stdout(Stdio::null())
                .stderr(Stdio::null())
                .status()
                .map(|s| s.success())
                .unwrap_or(false)
        })
    }

    /// Lists the renders to perform as `(jsx attribute name, mermaid theme)` pairs.
    ///
    /// A single theme yields one render stored under `chartSvg`; a theme map yields one
    /// render per entry, stored under `"{key}Svg"`.
    fn theme_renders(&self) -> Vec<(String, String)> {
        match &self.opts.theme {
            MermaidThemeMode::Single(name) => vec![("chartSvg".to_string(), name.clone())],
            MermaidThemeMode::Multi(map) => {
                map.iter().map(|(k, v)| (format!("{k}Svg"), v.clone())).collect()
            },
        }
    }

    /// Locks the in-memory cache. A poisoned lock is recovered because the map only ever
    /// holds complete, independent entries.
    fn memory_cache(&self) -> std::sync::MutexGuard<'_, HashMap<u64, String>> {
        self.cache.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Computes the cache key for rendering `source` with `theme` under the current options.
    ///
    /// Besides theme and source, every option that changes the produced SVG is folded in
    /// (mermaid config, background colour, post-processing flags). This keeps files persisted
    /// in `output_dir` by a run with different settings from being reused.
    /// `DefaultHasher` is only stable within one Rust release, so a toolchain upgrade merely
    /// causes on-disk cache misses, never wrong hits.
    fn cache_key(&self, source: &str, theme: &str) -> u64 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        theme.hash(&mut hasher);
        source.hash(&mut hasher);
        self.build_mermaid_config().to_string().hash(&mut hasher);
        self.opts.background_color.as_deref().unwrap_or("transparent").hash(&mut hasher);
        self.opts.responsive_svg.unwrap_or(true).hash(&mut hasher);
        self.opts.center_labels.unwrap_or(true).hash(&mut hasher);
        hasher.finish()
    }

    /// Renders `source` with `theme`, consulting the memory cache, then `output_dir` (if
    /// configured), before invoking `mmdc`.
    ///
    /// The on-disk copy is best-effort in both directions. Any read failure (missing,
    /// unreadable, not UTF-8) falls through to a fresh render, and write failures are
    /// ignored.
    ///
    /// # Errors
    /// Returns the `mmdc` failure message when a fresh render is needed and fails.
    fn render_cached(&self, source: &str, theme: &str) -> Result<String, String> {
        let key = self.cache_key(source, theme);
        if let Some(svg) = self.memory_cache().get(&key) {
            return Ok(svg.clone());
        }
        let disk_path = self.opts.output_dir.as_ref().map(|dir| dir.join(format!("{key}.svg")));
        let svg = match disk_path.as_ref().and_then(|p| std::fs::read_to_string(p).ok()) {
            Some(svg) => svg,
            None => {
                let svg = self.render_mmdc(source, theme)?;
                if let (Some(dir), Some(path)) = (&self.opts.output_dir, &disk_path) {
                    // Deliberately best-effort: a failed persist only costs a re-render later.
                    let _ = std::fs::create_dir_all(dir);
                    let _ = std::fs::write(path, &svg);
                }
                svg
            },
        };
        self.memory_cache().insert(key, svg.clone());
        Ok(svg)
    }

    /// Builds the JSON mermaid config handed to `mmdc --configFile`.
    ///
    /// It starts from built-in flowchart defaults. It then shallow-merges the serialized
    /// options, minus keys consumed elsewhere (theme, post-processing, paths, background)
    /// and minus `null` values, so unset options do not clobber the defaults.
    fn build_mermaid_config(&self) -> serde_json::Value {
        let html_labels = self.opts.html_labels.unwrap_or(false);
        let mut base = serde_json::json!({
            "htmlLabels": html_labels,
            "flowchart": {
                "htmlLabels": html_labels,
                "useMaxWidth": true,
                "nodeSpacing": 50,
                "rankSpacing": 60,
                "padding": 20,
            }
        });
        if let Ok(serde_json::Value::Object(mut user)) = serde_json::to_value(&self.opts) {
            for k in ["theme", "responsiveSvg", "centerLabels", "outputDir", "puppeteerConfigFile", "backgroundColor"] {
                user.remove(k);
            }
            user.retain(|_, v| !v.is_null());
            if !user.is_empty() {
                shallow_merge(&mut base, &serde_json::Value::Object(user));
            }
        }
        base
    }

    /// Runs `mmdc` once, feeding `source` on stdin and reading the SVG from stdout, then
    /// applies `post_process`.
    ///
    /// # Errors
    /// Returns a human-readable message when the config temp file cannot be written, the
    /// process cannot be spawned or waited on, `mmdc` exits unsuccessfully (its stderr, or
    /// the exit status if stderr is blank), the chart could not be written to stdin, or the
    /// output is empty or not UTF-8.
    fn render_mmdc(&self, source: &str, theme: &str) -> Result<String, String> {
        let cfg_str = self.build_mermaid_config().to_string();
        // `cfg_file` lives until the end of this function, so its path stays valid while mmdc runs.
        let mut cfg_file = tempfile::Builder::new()
            .prefix("dmc-mermaid-config-")
            .suffix(".json")
            .tempfile()
            .map_err(|e| format!("config temp file failed: {e}"))?;
        cfg_file.write_all(cfg_str.as_bytes()).map_err(|e| format!("config write failed: {e}"))?;
        cfg_file.flush().map_err(|e| format!("config write failed: {e}"))?;

        let bg = self.opts.background_color.as_deref().unwrap_or("transparent");
        let mut cmd = Command::new("mmdc");
        cmd.args(["--input", "-", "--output", "-", "--outputFormat", "svg"])
            .args(["--theme", theme, "--backgroundColor", bg])
            .arg("--configFile")
            // Passed as an OsStr so a non-UTF-8 temp path is not replaced by "".
            .arg(cfg_file.path())
            .arg("--quiet");
        if let Some(p) = &self.opts.puppeteer_config_file {
            cmd.arg("--puppeteerConfigFile").arg(p.to_string_lossy().into_owned());
        }
        let mut child = cmd
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .map_err(|e| format!("spawn failed: {e}"))?;

        // Write the chart, then drop the handle so mmdc sees EOF. A write failure is held back
        // so the child is still reaped and its stderr, which usually explains why it stopped
        // reading, takes precedence over a bare "Broken pipe".
        let fed = match child.stdin.take() {
            Some(mut stdin) => {
                stdin.write_all(source.as_bytes()).map_err(|e| format!("stdin write failed: {e}"))
            },
            None => Err("no stdin handle".to_string()),
        };
        let out = child.wait_with_output().map_err(|e| format!("wait failed: {e}"))?;
        if !out.status.success() {
            let err = String::from_utf8_lossy(&out.stderr).into_owned();
            return Err(if err.trim().is_empty() { format!("exit {}", out.status) } else { err });
        }
        fed?;
        let svg = String::from_utf8(out.stdout).map_err(|e| format!("non-utf8 svg: {e}"))?;
        // Never cache (possibly to disk) an empty "diagram".
        if svg.trim().is_empty() {
            return Err("mmdc produced no output".to_string());
        }
        Ok(self.post_process(&svg))
    }

    /// Applies the optional SVG rewrites: responsive root width (default on) and centred
    /// labels (default on).
    fn post_process(&self, svg: &str) -> String {
        let mut out = svg.to_string();
        if self.opts.responsive_svg.unwrap_or(true) {
            out = make_responsive(&out);
        }
        if self.opts.center_labels.unwrap_or(true) {
            out = center_labels(&out);
        }
        out
    }

    /// Renders `chart` once per configured theme, keyed by the JSX attribute name.
    ///
    /// On the first failure, pushes a `MermaidRenderFailed` diagnostic labelled at `span`
    /// onto `pending` and returns `None`; no partial result is returned.
    fn render_all(
        &self,
        chart: &str,
        span: &duck_diagnostic::Span,
        pending: &mut Vec<Diagnostic<Code>>,
    ) -> Option<BTreeMap<String, String>> {
        let mut out = BTreeMap::new();
        for (attr, theme) in self.theme_renders() {
            match self.render_cached(chart, &theme) {
                Ok(s) => {
                    out.insert(attr, s);
                },
                Err(err) => {
                    pending.push(
                        diag!(Code::MermaidRenderFailed, format!("mermaid ({theme}): mmdc failed - {}", err.trim()))
                            .with_label(Label::primary(span.clone(), Some("for this mermaid block".into()))),
                    );
                    return None;
                },
            }
        }
        Some(out)
    }
}
fn shallow_merge(base: &mut serde_json::Value, extra: &serde_json::Value) {
use serde_json::Value;
if let (Value::Object(b), Value::Object(e)) = (base, extra) {
for (k, v) in e {
b.insert(k.clone(), v.clone());
}
}
}
/// Rewrites the `width` attribute of the root `<svg …>` opening tag to `100%`.
///
/// Only the first `<svg` tag is inspected, and only an attribute named exactly `width`
/// (preceded by whitespace) is changed. Lookalikes such as `stroke-width="…"` and widths on
/// child elements are left alone. Input without such an attribute is returned unchanged.
fn make_responsive(svg: &str) -> String {
    const ATTR: &str = "width=\"";
    let Some(tag_start) = svg.find("<svg") else {
        return svg.to_string();
    };
    // Exclusive end of the opening tag; an unterminated tag extends to the end of input.
    let tag_end = svg[tag_start..].find('>').map_or(svg.len(), |off| tag_start + off);
    let tag = &svg[tag_start..tag_end];
    let Some(attr_at) = tag
        .match_indices(ATTR)
        .map(|(i, _)| i)
        .find(|&i| tag[..i].ends_with(|c: char| c.is_ascii_whitespace()))
    else {
        return svg.to_string();
    };
    let value_start = tag_start + attr_at + ATTR.len();
    let Some(value_len) = svg[value_start..tag_end].find('"') else {
        return svg.to_string();
    };
    let head = &svg[..tag_start + attr_at];
    let tail = &svg[value_start + value_len + 1..];
    format!("{head}width=\"100%\"{tail}")
}
/// Adds `text-anchor="middle"` to the two label element shapes emitted by mermaid:
/// `<text y="-10.1" …>` and `<tspan class="text-outer-tspan row" x="0" …>`.
///
/// This is a plain textual substitution applied in the listed order.
fn center_labels(svg: &str) -> String {
    const REWRITES: [(&str, &str); 2] = [
        ("<text y=\"-10.1\"", "<text y=\"-10.1\" text-anchor=\"middle\""),
        (
            "<tspan class=\"text-outer-tspan row\" x=\"0\"",
            "<tspan class=\"text-outer-tspan row\" x=\"0\" text-anchor=\"middle\"",
        ),
    ];
    REWRITES.iter().fold(svg.to_owned(), |text, &(from, to)| text.replace(from, to))
}
impl Transformer for Mermaid {
fn name(&self) -> &str {
"mermaid"
}
fn transform(
&self,
doc: &mut Document,
_meta: &SourceMeta,
diag_engine: &mut duck_diagnostic::DiagnosticEngine<Code>,
) {
if !Self::mmdc_available() {
diag_engine.emit(diag!(
Code::MmdcUnavailable,
"mermaid: `mmdc` is not on PATH; mermaid blocks left as code (install with `npm i -g @mermaid-js/mermaid-cli`)"
));
return;
}
let mut v = Apply { pending: Vec::new(), mermaid: self };
walk_root(&mut doc.children, &mut v);
for d in v.pending.drain(..) {
diag_engine.emit(d);
}
}
}
/// Tree visitor that renders mermaid nodes in place using a borrowed `Mermaid`.
struct Apply<'a> {
    /// Diagnostics collected during the walk; `Mermaid::transform` emits them afterwards.
    pending: Vec<Diagnostic<Code>>,
    /// Renderer (options and cache) used for every node in the walk.
    mermaid: &'a Mermaid,
}
impl Apply<'_> {
    /// Builds the attribute list for a rendered `MermaidDiagram` element.
    ///
    /// Attributes from `extra` are kept in their original order, except `chart` and any
    /// attribute named like a rendered-SVG key, since those are regenerated. Then `chart` is
    /// appended, followed by one string attribute per entry of `svgs` in key order. Every
    /// generated attribute carries `span`.
    fn jsx_attrs_with_svgs(
        chart: String,
        svgs: BTreeMap<String, String>,
        span: &duck_diagnostic::Span,
        extra: Vec<JsxAttr>,
    ) -> Vec<JsxAttr> {
        let mut attrs = extra;
        attrs.retain(|attr| attr.name != "chart" && !svgs.contains_key(attr.name.as_str()));
        attrs.reserve(1 + svgs.len());
        attrs.push(JsxAttr { name: "chart".into(), value: JsxAttrValue::String(chart), span: span.clone() });
        attrs.extend(svgs.into_iter().map(|(name, svg)| JsxAttr {
            name,
            value: JsxAttrValue::String(svg),
            span: span.clone(),
        }));
        attrs
    }
}
impl Visitor for Apply<'_> {
    /// Renders mermaid content found at `node`.
    ///
    /// - A ```` ```mermaid ```` code block is replaced by a self-closing `MermaidDiagram`
    ///   carrying `chart` plus one attribute per rendered theme.
    /// - An existing `MermaidDiagram` (self-closing or with children) whose `chart` is
    ///   statically known gets its SVG attributes regenerated.
    ///
    /// Successfully rendered nodes skip their children. On failure the node is kept as-is;
    /// `render_all` has already queued a diagnostic in `pending`.
    fn visit_node(&mut self, node: &mut Node) -> NodeAction {
        match node {
            Node::CodeBlock(block) if block.lang.as_deref() == Some("mermaid") => {
                let source = block.value.clone();
                let span = block.span.clone();
                match self.mermaid.render_all(&source, &span, &mut self.pending) {
                    Some(svgs) => {
                        let attrs = Apply::jsx_attrs_with_svgs(source, svgs, &span, Vec::new());
                        let diagram = JsxSelfClosing { name: "MermaidDiagram".into(), attrs, span };
                        *node = Node::JsxSelfClosing(diagram);
                        NodeAction::KeepSkipChildren
                    },
                    None => NodeAction::Keep,
                }
            },
            Node::JsxSelfClosing(tag) if tag.name == "MermaidDiagram" => {
                let span = tag.span.clone();
                let Some(source) = extract_chart_attr(&tag.attrs) else {
                    return NodeAction::Keep;
                };
                match self.mermaid.render_all(&source, &span, &mut self.pending) {
                    Some(svgs) => {
                        let previous = std::mem::take(&mut tag.attrs);
                        tag.attrs = Apply::jsx_attrs_with_svgs(source, svgs, &span, previous);
                        NodeAction::KeepSkipChildren
                    },
                    None => NodeAction::Keep,
                }
            },
            Node::JsxElement(element) if element.name == "MermaidDiagram" => {
                let span = element.span.clone();
                let Some(source) = extract_chart_attr(&element.attrs) else {
                    return NodeAction::Keep;
                };
                match self.mermaid.render_all(&source, &span, &mut self.pending) {
                    Some(svgs) => {
                        let previous = std::mem::take(&mut element.attrs);
                        element.attrs = Apply::jsx_attrs_with_svgs(source, svgs, &span, previous);
                        NodeAction::KeepSkipChildren
                    },
                    None => NodeAction::Keep,
                }
            },
            _ => NodeAction::Keep,
        }
    }
}
/// Returns the mermaid source carried by the first `chart` attribute, if it is statically
/// known.
///
/// - String attributes are returned verbatim.
/// - Expression attributes are accepted only when the trimmed expression is one literal
///   delimited by matching backticks, double quotes or single quotes. The delimiters are
///   stripped; escape sequences are not interpreted.
/// - Template literals containing `${` are rejected, since their value depends on runtime
///   interpolation.
///
/// Returns `None` when there is no `chart` attribute, or when its value is a boolean, a
/// spread, or a non-literal expression. A lone quote character is not a literal and yields
/// `None` rather than panicking.
fn extract_chart_attr(attrs: &[JsxAttr]) -> Option<String> {
    let attr = attrs.iter().find(|a| a.name == "chart")?;
    match &attr.value {
        JsxAttrValue::String(s) => Some(s.clone()),
        JsxAttrValue::Expression(e) => {
            let t = e.trim();
            // strip_prefix + strip_suffix require two distinct delimiter chars, so a
            // one-character expression like "`" cannot produce an inverted slice.
            let body = ['`', '"', '\''].into_iter().find_map(|q| t.strip_prefix(q)?.strip_suffix(q))?;
            if t.starts_with('`') && body.contains("${") {
                return None;
            }
            Some(body.to_string())
        },
        JsxAttrValue::Boolean | JsxAttrValue::Spread(_) => None,
    }
}