1use crate::config::{MermaidOptions, MermaidThemeMode};
4use crate::pipeline::Transformer;
5use crate::visit::{NodeAction, Visitor, walk_root};
6use dmc_diagnostic::Code;
7use dmc_diagnostic::metadata::SourceMeta;
8use dmc_parser::ast::*;
9use duck_diagnostic::{Diagnostic, Label, diag};
10use std::collections::{BTreeMap, HashMap};
11use std::io::Write;
12use std::path::PathBuf;
13use std::process::{Command, Stdio};
14use std::sync::{Mutex, OnceLock};
15
/// Transformer that renders mermaid diagrams to inline SVG using the `mmdc` CLI.
///
/// Fenced code blocks tagged `mermaid` are replaced by `<MermaidDiagram>` JSX
/// nodes, and existing `<MermaidDiagram chart=...>` elements get their SVG
/// attributes (re)generated.
pub struct Mermaid {
    /// User-supplied rendering options (theme selection, optional output/cache
    /// directory, mmdc flags, SVG post-processing toggles).
    opts: MermaidOptions,
    /// In-memory render cache: 64-bit render key (computed in `render_cached`)
    /// -> post-processed SVG. Wrapped in a `Mutex` so rendering only needs `&self`.
    cache: Mutex<HashMap<u64, String>>,
}

/// Memoised result of probing for a working `mmdc` binary; initialised on first
/// use by `Mermaid::mmdc_available` and never re-probed within the process.
static MMDC_AVAILABLE: OnceLock<bool> = OnceLock::new();
43
44impl Default for Mermaid {
45 fn default() -> Self {
46 Self::from_options(MermaidOptions::default())
47 }
48}
49
50impl Mermaid {
51 pub fn from_options(opts: MermaidOptions) -> Self {
54 Self { opts, cache: Mutex::new(HashMap::new()) }
55 }
56
57 pub fn with_output(p: impl Into<PathBuf>) -> Self {
60 Self::from_options(MermaidOptions { output_dir: Some(p.into()), ..Default::default() })
61 }
62
63 fn mmdc_available() -> bool {
64 *MMDC_AVAILABLE.get_or_init(|| {
65 Command::new("mmdc")
66 .arg("--version")
67 .stdout(Stdio::null())
68 .stderr(Stdio::null())
69 .status()
70 .map(|s| s.success())
71 .unwrap_or(false)
72 })
73 }
74
75 fn theme_renders(&self) -> Vec<(String, String)> {
79 match &self.opts.theme {
80 MermaidThemeMode::Single(name) => vec![("chartSvg".to_string(), name.clone())],
81 MermaidThemeMode::Multi(map) => map.iter().map(|(k, v)| (format!("{k}Svg"), v.clone())).collect(),
82 }
83 }
84
85 fn render_cached(&self, source: &str, theme: &str) -> Result<String, String> {
86 let key = {
87 use std::hash::{Hash, Hasher};
88 let mut hasher = std::collections::hash_map::DefaultHasher::new();
89 theme.hash(&mut hasher);
90 source.hash(&mut hasher);
91 hasher.finish()
92 };
93
94 if let Some(svg) = self.cache.lock().unwrap().get(&key) {
95 return Ok(svg.clone());
96 }
97
98 if let Some(dir) = &self.opts.output_dir {
99 let path = dir.join(format!("{key}.svg"));
100 match std::fs::read_to_string(&path) {
101 Ok(s) => return Ok(s),
102 Err(e) => {
103 if e.kind() != std::io::ErrorKind::NotFound {
104 return Err(e.to_string());
105 }
106 },
107 }
108 }
109
110 let svg = self.render_mmdc(source, theme)?;
111 self.cache.lock().unwrap().insert(key, svg.clone());
112 if let Some(dir) = &self.opts.output_dir {
113 let _ = std::fs::create_dir_all(dir);
114 let path = dir.join(format!("{key}.svg"));
115 let _ = std::fs::write(&path, &svg).map_err(|e| e.to_string());
116 }
117
118 Ok(svg)
119 }
120
121 fn build_mermaid_config(&self) -> serde_json::Value {
126 let html_labels = self.opts.html_labels.unwrap_or(false);
127 let mut base = serde_json::json!({
128 "htmlLabels": html_labels,
129 "flowchart": {
130 "htmlLabels": html_labels,
131 "useMaxWidth": true,
132 "nodeSpacing": 50,
133 "rankSpacing": 60,
134 "padding": 20,
135 }
136 });
137 if let Ok(serde_json::Value::Object(mut user)) = serde_json::to_value(&self.opts) {
142 for k in ["theme", "responsiveSvg", "centerLabels", "outputDir", "puppeteerConfigFile", "backgroundColor"] {
143 user.remove(k);
144 }
145 if !user.is_empty() {
146 shallow_merge(&mut base, &serde_json::Value::Object(user));
147 }
148 }
149 base
150 }
151
152 fn render_mmdc(&self, source: &str, theme: &str) -> Result<String, String> {
155 let cfg_json = self.build_mermaid_config();
156 let cfg_str = cfg_json.to_string();
157 let cfg_dir = std::env::temp_dir();
158 let cfg_hash = {
161 use std::hash::{Hash, Hasher};
162 let mut hasher = std::collections::hash_map::DefaultHasher::new();
163 cfg_str.hash(&mut hasher);
164 hasher.finish()
165 };
166 let cfg_path = cfg_dir.join(format!("dmc-mermaid-config-{}-{cfg_hash:x}.json", std::process::id()));
167 if !cfg_path.exists() {
168 std::fs::write(&cfg_path, &cfg_str).map_err(|e| format!("config write failed: {e}"))?;
169 }
170
171 let bg = self.opts.background_color.as_deref().unwrap_or("transparent");
172 let cfg_path_str = cfg_path.to_str().unwrap_or("").to_string();
173
174 let mut args: Vec<String> = vec![
175 "--input".into(),
176 "-".into(),
177 "--output".into(),
178 "-".into(),
179 "--outputFormat".into(),
180 "svg".into(),
181 "--theme".into(),
182 theme.to_string(),
183 "--backgroundColor".into(),
184 bg.to_string(),
185 "--configFile".into(),
186 cfg_path_str,
187 "--quiet".into(),
188 ];
189 if let Some(p) = &self.opts.puppeteer_config_file {
190 args.push("--puppeteerConfigFile".into());
191 args.push(p.to_string_lossy().into_owned());
192 }
193
194 let mut child = Command::new("mmdc")
195 .args(&args)
196 .stdin(Stdio::piped())
197 .stdout(Stdio::piped())
198 .stderr(Stdio::piped())
199 .spawn()
200 .map_err(|e| format!("spawn failed: {e}"))?;
201 child
202 .stdin
203 .as_mut()
204 .ok_or_else(|| "no stdin handle".to_string())?
205 .write_all(source.as_bytes())
206 .map_err(|e| format!("stdin write failed: {e}"))?;
207 let out = child.wait_with_output().map_err(|e| format!("wait failed: {e}"))?;
208 if !out.status.success() {
209 let err = String::from_utf8_lossy(&out.stderr).into_owned();
210 return Err(if err.is_empty() { format!("exit {}", out.status) } else { err });
211 }
212 let svg = String::from_utf8(out.stdout).map_err(|e| format!("non-utf8 svg: {e}"))?;
213 Ok(self.post_process(&svg))
214 }
215
216 fn post_process(&self, svg: &str) -> String {
219 let mut out = svg.to_string();
220 if self.opts.responsive_svg.unwrap_or(true) {
221 out = make_responsive(&out);
222 }
223 if self.opts.center_labels.unwrap_or(true) {
224 out = center_labels(&out);
226 }
227 out
228 }
229
230 fn render_all(
234 &self,
235 chart: &str,
236 span: &duck_diagnostic::Span,
237 pending: &mut Vec<Diagnostic<Code>>,
238 ) -> Option<BTreeMap<String, String>> {
239 let mut out = BTreeMap::new();
240 for (attr, theme) in self.theme_renders() {
241 match self.render_cached(chart, &theme) {
242 Ok(s) => {
243 out.insert(attr, s);
244 },
245 Err(err) => {
246 pending.push(
247 diag!(Code::MermaidRenderFailed, format!("mermaid ({theme}): mmdc failed - {}", err.trim()))
248 .with_label(Label::primary(span.clone(), Some("for this mermaid block".into()))),
249 );
250 return None;
251 },
252 }
253 }
254 Some(out)
255 }
256}
257
258fn shallow_merge(base: &mut serde_json::Value, extra: &serde_json::Value) {
264 use serde_json::Value;
265 if let (Value::Object(b), Value::Object(e)) = (base, extra) {
266 for (k, v) in e {
267 b.insert(k.clone(), v.clone());
268 }
269 }
270}
271
/// Force the root `<svg>` element's `width` attribute to `100%`.
///
/// Only a standalone `width="…"` attribute inside the opening `<svg …>` tag is
/// rewritten; it must be preceded by whitespace. Earlier versions replaced the
/// first `width="` anywhere in the document. When the root tag had no width,
/// that could corrupt e.g. a child's `stroke-width="…"` or a `data-width="…"`
/// attribute. Input without a root `<svg` tag, or whose root tag has no such
/// attribute, is returned unchanged.
fn make_responsive(svg: &str) -> String {
    const ATTR: &str = "width=\"";

    let Some(tag_start) = svg.find("<svg") else {
        return svg.to_string();
    };
    let tag_end = svg[tag_start..].find('>').map_or(svg.len(), |i| tag_start + i);
    let tag = &svg[tag_start..tag_end];

    let mut from = 0;
    while let Some(rel) = tag[from..].find(ATTR) {
        let idx = from + rel;
        // Reject suffix matches such as `stroke-width` or `data-width`.
        if tag[..idx].ends_with(|c: char| c.is_ascii_whitespace()) {
            let value_start = tag_start + idx + ATTR.len();
            let Some(value_len) = svg[value_start..tag_end].find('"') else {
                break;
            };
            let head = &svg[..tag_start + idx];
            let tail = &svg[value_start + value_len + 1..];
            return format!("{head}width=\"100%\"{tail}");
        }
        from = idx + ATTR.len();
    }
    svg.to_string()
}
285
/// Add `text-anchor="middle"` to the label elements mmdc emits, so label text is
/// centred on its anchor point.
///
/// This is a purely textual rewrite keyed on the exact attribute prefixes below.
/// The rewrites are applied in order, and every occurrence of each prefix is
/// rewritten.
fn center_labels(svg: &str) -> String {
    const REWRITES: [(&str, &str); 2] = [
        ("<text y=\"-10.1\"", "<text y=\"-10.1\" text-anchor=\"middle\""),
        (
            "<tspan class=\"text-outer-tspan row\" x=\"0\"",
            "<tspan class=\"text-outer-tspan row\" x=\"0\" text-anchor=\"middle\"",
        ),
    ];
    REWRITES.iter().fold(svg.to_string(), |acc, &(from, to)| acc.replace(from, to))
}
300
301impl Transformer for Mermaid {
302 fn name(&self) -> &str {
303 "mermaid"
304 }
305 fn transform(
306 &self,
307 doc: &mut Document,
308 _meta: &SourceMeta,
309 diag_engine: &mut duck_diagnostic::DiagnosticEngine<Code>,
310 ) {
311 if !Self::mmdc_available() {
312 diag_engine.emit(diag!(
313 Code::MmdcUnavailable,
314 "mermaid: `mmdc` is not on PATH; mermaid blocks left as code (install with `npm i -g @mermaid-js/mermaid-cli`)"
315 ));
316 return;
317 }
318 let mut v = Apply { pending: Vec::new(), mermaid: self };
319 walk_root(&mut doc.children, &mut v);
320 for d in v.pending.drain(..) {
321 diag_engine.emit(d);
322 }
323 }
324}
325
/// AST visitor that renders the mermaid charts it encounters during a document walk.
struct Apply<'a> {
    /// Render-failure diagnostics collected during the walk; the transformer
    /// emits them after the walk completes.
    pending: Vec<Diagnostic<Code>>,
    /// Renderer (and its caches) used for every chart found.
    mermaid: &'a Mermaid,
}
330
331impl<'a> Apply<'a> {
332 fn jsx_attrs_with_svgs(
336 chart: String,
337 svgs: BTreeMap<String, String>,
338 span: &duck_diagnostic::Span,
339 extra: Vec<JsxAttr>,
340 ) -> Vec<JsxAttr> {
341 let svg_keys: std::collections::HashSet<&str> = svgs.keys().map(String::as_str).collect();
342 let mut out: Vec<JsxAttr> =
343 extra.into_iter().filter(|a| a.name != "chart" && !svg_keys.contains(a.name.as_str())).collect();
344 out.push(JsxAttr { name: "chart".into(), value: JsxAttrValue::String(chart), span: span.clone() });
345 for (k, v) in svgs {
346 out.push(JsxAttr { name: k, value: JsxAttrValue::String(v), span: span.clone() });
347 }
348 out
349 }
350}
351
352impl<'a> Visitor for Apply<'a> {
353 fn visit_node(&mut self, node: &mut Node) -> NodeAction {
354 match node {
355 Node::CodeBlock(cb) if cb.lang.as_deref() == Some("mermaid") => {
357 let span = cb.span.clone();
358 let chart = cb.value.clone();
359 let Some(svgs) = self.mermaid.render_all(&chart, &span, &mut self.pending) else {
360 return NodeAction::Keep;
361 };
362 let attrs = Apply::jsx_attrs_with_svgs(chart, svgs, &span, Vec::new());
363 *node = Node::JsxSelfClosing(JsxSelfClosing { name: "MermaidDiagram".into(), attrs, span });
364 NodeAction::KeepSkipChildren
365 },
366 Node::JsxSelfClosing(jsc) if jsc.name == "MermaidDiagram" => {
368 let span = jsc.span.clone();
369 let Some(chart) = extract_chart_attr(&jsc.attrs) else { return NodeAction::Keep };
370 let Some(svgs) = self.mermaid.render_all(&chart, &span, &mut self.pending) else {
371 return NodeAction::Keep;
372 };
373 let extra = std::mem::take(&mut jsc.attrs);
374 jsc.attrs = Apply::jsx_attrs_with_svgs(chart, svgs, &span, extra);
375 NodeAction::KeepSkipChildren
376 },
377 Node::JsxElement(je) if je.name == "MermaidDiagram" => {
378 let span = je.span.clone();
379 let Some(chart) = extract_chart_attr(&je.attrs) else { return NodeAction::Keep };
380 let Some(svgs) = self.mermaid.render_all(&chart, &span, &mut self.pending) else {
381 return NodeAction::Keep;
382 };
383 let extra = std::mem::take(&mut je.attrs);
384 je.attrs = Apply::jsx_attrs_with_svgs(chart, svgs, &span, extra);
385 NodeAction::KeepSkipChildren
386 },
387 _ => NodeAction::Keep,
388 }
389 }
390}
391
392fn extract_chart_attr(attrs: &[JsxAttr]) -> Option<String> {
398 let attr = attrs.iter().find(|a| a.name == "chart")?;
399 match &attr.value {
400 JsxAttrValue::String(s) => Some(s.clone()),
401 JsxAttrValue::Expression(e) => {
402 let t = e.trim();
403 if (t.starts_with('`') && t.ends_with('`'))
404 || (t.starts_with('"') && t.ends_with('"'))
405 || (t.starts_with('\'') && t.ends_with('\''))
406 {
407 Some(t[1..t.len() - 1].to_string())
408 } else {
409 None
410 }
411 },
412 JsxAttrValue::Boolean | JsxAttrValue::Spread(_) => None,
413 }
414}