Skip to main content

dmc_transform/builtin/
mermaid.rs

1//! Mermaid pre-renderer. See `transformers/mermaid.md` for full docs.
2
3use crate::config::{MermaidOptions, MermaidThemeMode};
4use crate::pipeline::Transformer;
5use crate::visit::{NodeAction, Visitor, walk_root};
6use dmc_diagnostic::Code;
7use dmc_diagnostic::metadata::SourceMeta;
8use dmc_parser::ast::*;
9use duck_diagnostic::{Diagnostic, Label, diag};
10use std::collections::{BTreeMap, HashMap};
11use std::io::Write;
12use std::path::PathBuf;
13use std::process::{Command, Stdio};
14use std::sync::{Mutex, OnceLock};
15
/// Pre-render mermaid diagrams to inline SVG via the external `mmdc` CLI
/// (`@mermaid-js/mermaid-cli`).
///
/// Three input shapes are handled:
///   * ` ```mermaid ` fenced code blocks - replaced with
///     `<MermaidDiagram chart="..." {mode}Svg="<svg...>" ... />`.
///   * Author-written `<MermaidDiagram chart={`...`} />` JSX nodes - the
///     existing node is preserved; `chart` and every `{mode}Svg` attribute
///     are (re)set, other author attributes are kept.
///   * The non-self-closing `<MermaidDiagram chart=...>...</MermaidDiagram>`
///     form - treated like the self-closing one.
///
/// Theme behavior is driven by [`MermaidOptions::theme`]:
/// `Single("dark")` renders once and emits a single `chartSvg` attr;
/// `Multi({ light: "default", dark: "dark" })` renders once per map entry
/// and emits one `{key}Svg` attr each (here `lightSvg` + `darkSvg`).
/// `Multi` is described as the default - NOTE(review): confirm against
/// `MermaidOptions::default()`.
///
/// Rendered SVGs are cached in memory for the lifetime of the instance
/// and, when `MermaidOptions::output_dir` is set, also on disk as
/// `<hash>.svg` files in that directory.
///
/// Per-block failures emit [`Code::MermaidRenderFailed`] and leave that
/// block unchanged. The CLI availability probe runs once per process;
/// missing CLI -> the whole transformer becomes a no-op with
/// [`Code::MmdcUnavailable`].
pub struct Mermaid {
  /// User-facing configuration; see [`MermaidOptions`].
  opts: MermaidOptions,
  /// Rendered-SVG cache keyed by a `u64` hash computed in
  /// `render_cached`; dedupes identical diagrams across a single compile
  /// run.
  cache: Mutex<HashMap<u64, String>>,
}
40
/// Process-wide memo of the one-shot CLI availability probe: `true` when
/// `mmdc --version` could be spawned and exited successfully. Filled
/// lazily by `Mermaid::mmdc_available` and shared by every `Mermaid`
/// instance, so the probe process is spawned at most once.
static MMDC_AVAILABLE: OnceLock<bool> = OnceLock::new();
43
44impl Default for Mermaid {
45  fn default() -> Self {
46    Self::from_options(MermaidOptions::default())
47  }
48}
49
50impl Mermaid {
51  /// Build a Mermaid transformer with the supplied options. Use
52  /// `Mermaid::default()` for the bundled defaults.
53  pub fn from_options(opts: MermaidOptions) -> Self {
54    Self { opts, cache: Mutex::new(HashMap::new()) }
55  }
56
57  /// Convenience constructor preserved for backward compat: enables
58  /// the disk cache at `dir`.
59  pub fn with_output(p: impl Into<PathBuf>) -> Self {
60    Self::from_options(MermaidOptions { output_dir: Some(p.into()), ..Default::default() })
61  }
62
63  fn mmdc_available() -> bool {
64    *MMDC_AVAILABLE.get_or_init(|| {
65      Command::new("mmdc")
66        .arg("--version")
67        .stdout(Stdio::null())
68        .stderr(Stdio::null())
69        .status()
70        .map(|s| s.success())
71        .unwrap_or(false)
72    })
73  }
74
75  /// Iterate the requested modes as `(jsx_attr_name, mermaid_theme)` pairs.
76  /// `Single("dark")` -> `[("chartSvg", "dark")]`.
77  /// `Multi({"light":"default","dark":"dark"})` -> `[("lightSvg","default"), ("darkSvg","dark")]`.
78  fn theme_renders(&self) -> Vec<(String, String)> {
79    match &self.opts.theme {
80      MermaidThemeMode::Single(name) => vec![("chartSvg".to_string(), name.clone())],
81      MermaidThemeMode::Multi(map) => map.iter().map(|(k, v)| (format!("{k}Svg"), v.clone())).collect(),
82    }
83  }
84
85  fn render_cached(&self, source: &str, theme: &str) -> Result<String, String> {
86    let key = {
87      use std::hash::{Hash, Hasher};
88      let mut hasher = std::collections::hash_map::DefaultHasher::new();
89      theme.hash(&mut hasher);
90      source.hash(&mut hasher);
91      hasher.finish()
92    };
93
94    if let Some(svg) = self.cache.lock().unwrap().get(&key) {
95      return Ok(svg.clone());
96    }
97
98    if let Some(dir) = &self.opts.output_dir {
99      let path = dir.join(format!("{key}.svg"));
100      match std::fs::read_to_string(&path) {
101        Ok(s) => return Ok(s),
102        Err(e) => {
103          if e.kind() != std::io::ErrorKind::NotFound {
104            return Err(e.to_string());
105          }
106        },
107      }
108    }
109
110    let svg = self.render_mmdc(source, theme)?;
111    self.cache.lock().unwrap().insert(key, svg.clone());
112    if let Some(dir) = &self.opts.output_dir {
113      let _ = std::fs::create_dir_all(dir);
114      let path = dir.join(format!("{key}.svg"));
115      let _ = std::fs::write(&path, &svg).map_err(|e| e.to_string());
116    }
117
118    Ok(svg)
119  }
120
121  /// Build the mermaid `initialize` config that goes to
122  /// `mmdc --configFile`. dmc defaults: `htmlLabels:false` (root + nested
123  /// flowchart for safety), flowchart spacing knobs. User-supplied
124  /// initialize fields overlay these defaults via shallow merge.
125  fn build_mermaid_config(&self) -> serde_json::Value {
126    let html_labels = self.opts.html_labels.unwrap_or(false);
127    let mut base = serde_json::json!({
128      "htmlLabels": html_labels,
129      "flowchart": {
130        "htmlLabels": html_labels,
131        "useMaxWidth": true,
132        "nodeSpacing": 50,
133        "rankSpacing": 60,
134        "padding": 20,
135      }
136    });
137    // Serialise the full options struct, then strip dmc-side keys
138    // (`theme`, `responsiveSvg`, `centerLabels`, `outputDir`,
139    // `puppeteerConfigFile`, `backgroundColor`) - every remaining field
140    // is part of the typed `mermaid.initialize()` surface.
141    if let Ok(serde_json::Value::Object(mut user)) = serde_json::to_value(&self.opts) {
142      for k in ["theme", "responsiveSvg", "centerLabels", "outputDir", "puppeteerConfigFile", "backgroundColor"] {
143        user.remove(k);
144      }
145      if !user.is_empty() {
146        shallow_merge(&mut base, &serde_json::Value::Object(user));
147      }
148    }
149    base
150  }
151
152  /// Run `mmdc` once for the given mermaid `source` + `theme`. Captures
153  /// stdout (the SVG markup); maps non-zero exit / stderr to an error.
154  fn render_mmdc(&self, source: &str, theme: &str) -> Result<String, String> {
155    let cfg_json = self.build_mermaid_config();
156    let cfg_str = cfg_json.to_string();
157    // SEC-005: write the `mmdc --configFile` config to a `tempfile`
158    // NamedTempFile — a randomised, exclusively-created path (no
159    // predictable name in the world-writable system temp dir, so a
160    // local attacker can't pre-plant or symlink it). The handle is held
161    // for the lifetime of the `mmdc` child and the file is unlinked on
162    // drop.
163    let mut cfg_file = tempfile::Builder::new()
164      .prefix("dmc-mermaid-config-")
165      .suffix(".json")
166      .tempfile()
167      .map_err(|e| format!("config temp file failed: {e}"))?;
168    cfg_file.write_all(cfg_str.as_bytes()).map_err(|e| format!("config write failed: {e}"))?;
169    cfg_file.flush().map_err(|e| format!("config write failed: {e}"))?;
170
171    let bg = self.opts.background_color.as_deref().unwrap_or("transparent");
172    let cfg_path_str = cfg_file.path().to_str().unwrap_or("").to_string();
173
174    let mut args: Vec<String> = vec![
175      "--input".into(),
176      "-".into(),
177      "--output".into(),
178      "-".into(),
179      "--outputFormat".into(),
180      "svg".into(),
181      "--theme".into(),
182      theme.to_string(),
183      "--backgroundColor".into(),
184      bg.to_string(),
185      "--configFile".into(),
186      cfg_path_str,
187      "--quiet".into(),
188    ];
189    if let Some(p) = &self.opts.puppeteer_config_file {
190      args.push("--puppeteerConfigFile".into());
191      args.push(p.to_string_lossy().into_owned());
192    }
193
194    let mut child = Command::new("mmdc")
195      .args(&args)
196      .stdin(Stdio::piped())
197      .stdout(Stdio::piped())
198      .stderr(Stdio::piped())
199      .spawn()
200      .map_err(|e| format!("spawn failed: {e}"))?;
201    child
202      .stdin
203      .as_mut()
204      .ok_or_else(|| "no stdin handle".to_string())?
205      .write_all(source.as_bytes())
206      .map_err(|e| format!("stdin write failed: {e}"))?;
207    let out = child.wait_with_output().map_err(|e| format!("wait failed: {e}"))?;
208    if !out.status.success() {
209      let err = String::from_utf8_lossy(&out.stderr).into_owned();
210      return Err(if err.is_empty() { format!("exit {}", out.status) } else { err });
211    }
212    let svg = String::from_utf8(out.stdout).map_err(|e| format!("non-utf8 svg: {e}"))?;
213    Ok(self.post_process(&svg))
214  }
215
216  /// Apply optional SVG post-processing: responsive width, centered
217  /// labels. Both default-on; toggleable via `MermaidOptions`.
218  fn post_process(&self, svg: &str) -> String {
219    let mut out = svg.to_string();
220    if self.opts.responsive_svg.unwrap_or(true) {
221      out = make_responsive(&out);
222    }
223    if self.opts.center_labels.unwrap_or(true) {
224      // Only meaningful with htmlLabels:false. Cheap no-op otherwise.
225      out = center_labels(&out);
226    }
227    out
228  }
229
230  /// Render every requested theme for `chart`, returning a map of
231  /// `{ jsx_attr_name -> svg_string }`. `None` if any theme errors out
232  /// (caller emits a diagnostic then).
233  fn render_all(
234    &self,
235    chart: &str,
236    span: &duck_diagnostic::Span,
237    pending: &mut Vec<Diagnostic<Code>>,
238  ) -> Option<BTreeMap<String, String>> {
239    let mut out = BTreeMap::new();
240    for (attr, theme) in self.theme_renders() {
241      match self.render_cached(chart, &theme) {
242        Ok(s) => {
243          out.insert(attr, s);
244        },
245        Err(err) => {
246          pending.push(
247            diag!(Code::MermaidRenderFailed, format!("mermaid ({theme}): mmdc failed - {}", err.trim()))
248              .with_label(Label::primary(span.clone(), Some("for this mermaid block".into()))),
249          );
250          return None;
251        },
252      }
253    }
254    Some(out)
255  }
256}
257
258/// Shallow-merge `extra` into `base` when both are JSON objects: keys in
259/// `extra` overwrite keys in `base`. Non-object `extra` is ignored. We
260/// intentionally don't recurse - mermaid's nested config (`flowchart`,
261/// `themeVariables`, ...) is small enough that a user passing a partial
262/// `flowchart` block should fully override our defaults for that block.
263fn shallow_merge(base: &mut serde_json::Value, extra: &serde_json::Value) {
264  use serde_json::Value;
265  if let (Value::Object(b), Value::Object(e)) = (base, extra) {
266    for (k, v) in e {
267      b.insert(k.clone(), v.clone());
268    }
269  }
270}
271
/// Rewrite the `width="..."` attribute of the root `<svg>` start tag to
/// `width="100%"` so the rendered diagram fluidly scales to its container.
///
/// Only the root start tag is searched, and the attribute must be preceded
/// by whitespace. This prevents a prefix match such as `stroke-width="` on
/// the root tag being rewritten. It also stops a nested element's `width`
/// from being touched when the root tag has no `width` of its own. Input
/// without a matching attribute is returned unchanged.
fn make_responsive(svg: &str) -> String {
  const ATTR: &str = "width=\"";

  let Some(tag_start) = svg.find("<svg") else { return svg.to_string() };
  let Some(tag_len) = svg[tag_start..].find('>') else { return svg.to_string() };
  let tag = &svg[tag_start..tag_start + tag_len];

  let mut search_from = 0;
  while let Some(rel) = tag[search_from..].find(ATTR) {
    let attr_start = search_from + rel;
    let value_start = attr_start + ATTR.len();
    // A real attribute name starts after whitespace; `stroke-width="` etc.
    // end in `width="` but are preceded by `-`.
    if tag[..attr_start].chars().next_back().is_some_and(char::is_whitespace) {
      let Some(value_len) = tag[value_start..].find('"') else { return svg.to_string() };
      let head = &svg[..tag_start + attr_start];
      let tail = &svg[tag_start + value_start + value_len + 1..];
      return format!("{head}width=\"100%\"{tail}");
    }
    search_from = value_start;
  }
  svg.to_string()
}
285
/// Center node labels in SVGs rendered with `htmlLabels:false`.
///
/// Per the original investigation, mermaid 11 emits such labels as `<text>`
/// without a `text-anchor`, plus an inner `<tspan x="0">` anchored at the
/// label's local origin (the node center). Text then starts at the node's
/// midpoint and overflows its right edge ("Accordion" -> "Accordio").
/// This adds `text-anchor="middle"` to the two exact tag prefixes below,
/// so `x="0"` becomes the midpoint of the line. Markup that doesn't match
/// those prefixes byte-for-byte passes through untouched.
fn center_labels(svg: &str) -> String {
  const PATCHES: [(&str, &str); 2] = [
    ("<text y=\"-10.1\"", "<text y=\"-10.1\" text-anchor=\"middle\""),
    (
      "<tspan class=\"text-outer-tspan row\" x=\"0\"",
      "<tspan class=\"text-outer-tspan row\" x=\"0\" text-anchor=\"middle\"",
    ),
  ];
  PATCHES.iter().fold(svg.to_owned(), |acc, &(needle, with)| acc.replace(needle, with))
}
300
301impl Transformer for Mermaid {
302  fn name(&self) -> &str {
303    "mermaid"
304  }
305  fn transform(
306    &self,
307    doc: &mut Document,
308    _meta: &SourceMeta,
309    diag_engine: &mut duck_diagnostic::DiagnosticEngine<Code>,
310  ) {
311    if !Self::mmdc_available() {
312      diag_engine.emit(diag!(
313        Code::MmdcUnavailable,
314        "mermaid: `mmdc` is not on PATH; mermaid blocks left as code (install with `npm i -g @mermaid-js/mermaid-cli`)"
315      ));
316      return;
317    }
318    let mut v = Apply { pending: Vec::new(), mermaid: self };
319    walk_root(&mut doc.children, &mut v);
320    for d in v.pending.drain(..) {
321      diag_engine.emit(d);
322    }
323  }
324}
325
/// Visitor state for one `Mermaid::transform` call: the transformer
/// (whose options and cache drive every render) plus the diagnostics
/// queued while walking the document.
struct Apply<'a> {
  /// Render-failure diagnostics collected during the walk; `transform`
  /// emits them after `walk_root` returns.
  pending: Vec<Diagnostic<Code>>,
  /// The transformer that renders each chart.
  mermaid: &'a Mermaid,
}
330
331impl<'a> Apply<'a> {
332  /// Build the JSX attr list for a `<MermaidDiagram>` node: keep author
333  /// extras (className, etc), but always (re)set `chart` and every
334  /// rendered `${mode}Svg` from `svgs`.
335  fn jsx_attrs_with_svgs(
336    chart: String,
337    svgs: BTreeMap<String, String>,
338    span: &duck_diagnostic::Span,
339    extra: Vec<JsxAttr>,
340  ) -> Vec<JsxAttr> {
341    let svg_keys: std::collections::HashSet<&str> = svgs.keys().map(String::as_str).collect();
342    let mut out: Vec<JsxAttr> =
343      extra.into_iter().filter(|a| a.name != "chart" && !svg_keys.contains(a.name.as_str())).collect();
344    out.push(JsxAttr { name: "chart".into(), value: JsxAttrValue::String(chart), span: span.clone() });
345    for (k, v) in svgs {
346      out.push(JsxAttr { name: k, value: JsxAttrValue::String(v), span: span.clone() });
347    }
348    out
349  }
350}
351
352impl<'a> Visitor for Apply<'a> {
353  fn visit_node(&mut self, node: &mut Node) -> NodeAction {
354    match node {
355      // ```mermaid ... ``` -> <MermaidDiagram chart {modeKey}Svg ... />
356      Node::CodeBlock(cb) if cb.lang.as_deref() == Some("mermaid") => {
357        let span = cb.span.clone();
358        let chart = cb.value.clone();
359        let Some(svgs) = self.mermaid.render_all(&chart, &span, &mut self.pending) else {
360          return NodeAction::Keep;
361        };
362        let attrs = Apply::jsx_attrs_with_svgs(chart, svgs, &span, Vec::new());
363        *node = Node::JsxSelfClosing(JsxSelfClosing { name: "MermaidDiagram".into(), attrs, span });
364        NodeAction::KeepSkipChildren
365      },
366      // <MermaidDiagram chart={`...`} /> -> same, with {modeKey}Svg appended.
367      Node::JsxSelfClosing(jsc) if jsc.name == "MermaidDiagram" => {
368        let span = jsc.span.clone();
369        let Some(chart) = extract_chart_attr(&jsc.attrs) else { return NodeAction::Keep };
370        let Some(svgs) = self.mermaid.render_all(&chart, &span, &mut self.pending) else {
371          return NodeAction::Keep;
372        };
373        let extra = std::mem::take(&mut jsc.attrs);
374        jsc.attrs = Apply::jsx_attrs_with_svgs(chart, svgs, &span, extra);
375        NodeAction::KeepSkipChildren
376      },
377      Node::JsxElement(je) if je.name == "MermaidDiagram" => {
378        let span = je.span.clone();
379        let Some(chart) = extract_chart_attr(&je.attrs) else { return NodeAction::Keep };
380        let Some(svgs) = self.mermaid.render_all(&chart, &span, &mut self.pending) else {
381          return NodeAction::Keep;
382        };
383        let extra = std::mem::take(&mut je.attrs);
384        je.attrs = Apply::jsx_attrs_with_svgs(chart, svgs, &span, extra);
385        NodeAction::KeepSkipChildren
386      },
387      _ => NodeAction::Keep,
388    }
389  }
390}
391
392/// Pull the `chart` attribute value out as a plain string. Handles both
393/// `chart="..."` (string literal) and `chart={`...`}` /
394/// `chart={"..."}` (expression carrying a single string / template). The
395/// expression branch trims the surrounding `"..."` or `` `...` `` so the
396/// extracted text is mermaid source ready for `mmdc`.
397fn extract_chart_attr(attrs: &[JsxAttr]) -> Option<String> {
398  let attr = attrs.iter().find(|a| a.name == "chart")?;
399  match &attr.value {
400    JsxAttrValue::String(s) => Some(s.clone()),
401    JsxAttrValue::Expression(e) => {
402      let t = e.trim();
403      if (t.starts_with('`') && t.ends_with('`'))
404        || (t.starts_with('"') && t.ends_with('"'))
405        || (t.starts_with('\'') && t.ends_with('\''))
406      {
407        Some(t[1..t.len() - 1].to_string())
408      } else {
409        None
410      }
411    },
412    JsxAttrValue::Boolean | JsxAttrValue::Spread(_) => None,
413  }
414}