1use crate::{MermaidConfig, Result};
2use regex::Regex;
3use std::borrow::Cow;
4use std::sync::OnceLock;
5
/// Defines a zero-argument function `$fn_name() -> &'static Regex` that compiles `$pat` the
/// first time it is called and hands back the same cached instance on every later call.
///
/// # Panics
///
/// The generated function panics on its first call if `$pat` is not a valid regex.
macro_rules! cached_regex {
    ($fn_name:ident, $pat:literal) => {
        fn $fn_name() -> &'static Regex {
            static CACHE: OnceLock<Regex> = OnceLock::new();
            CACHE.get_or_init(|| {
                let compiled = Regex::new($pat);
                compiled.expect("detector regex must compile")
            })
        }
    };
}
14
15#[derive(Debug, thiserror::Error)]
16#[error("No diagram type detected matching given configuration for text: {text}")]
17pub struct DetectTypeError {
18 pub text: String,
19}
20
21pub type DetectorFn = fn(text: &str, config: &mut MermaidConfig) -> bool;
22
23#[derive(Debug, Clone)]
24pub struct Detector {
25 pub id: &'static str,
26 pub detector: DetectorFn,
27}
28
29#[derive(Debug, Clone)]
30pub struct DetectorRegistry {
31 detectors: Vec<Detector>,
32 frontmatter_re: Regex,
33 any_comment_re: Regex,
34}
35
36impl DetectorRegistry {
37 pub fn new() -> Self {
38 Self {
39 detectors: Vec::new(),
40 frontmatter_re: Regex::new(r"(?s)^\s*-{3}\s*[\n\r](.*?)[\n\r]\s*-{3}\s*[\n\r]+")
44 .unwrap(),
45 any_comment_re: Regex::new(r"(?m)\s*%%.*\n").unwrap(),
46 }
47 }
48
49 pub fn add(&mut self, detector: Detector) {
50 self.detectors.push(detector);
51 }
52
53 pub fn add_fn(&mut self, id: &'static str, detector: DetectorFn) {
54 self.add(Detector { id, detector });
55 }
56
57 pub fn detect_type(&self, text: &str, config: &mut MermaidConfig) -> Result<&'static str> {
58 let no_frontmatter = self.frontmatter_re.replace(text, "");
59 let no_directives = remove_directives(no_frontmatter.as_ref());
60 let cleaned = self
61 .any_comment_re
62 .replace_all(no_directives.as_ref(), "\n");
63
64 if let Some(id) = fast_detect_by_leading_keyword(cleaned.as_ref()) {
65 return Ok(id);
66 }
67
68 for det in &self.detectors {
69 if (det.detector)(cleaned.as_ref(), config) {
70 return Ok(det.id);
71 }
72 }
73
74 Err(DetectTypeError {
75 text: cleaned.into_owned(),
76 }
77 .into())
78 }
79
80 pub fn detect_type_precleaned(
83 &self,
84 text: &str,
85 config: &mut MermaidConfig,
86 ) -> Result<&'static str> {
87 if let Some(id) = fast_detect_by_leading_keyword(text) {
88 return Ok(id);
89 }
90
91 for det in &self.detectors {
92 if (det.detector)(text, config) {
93 return Ok(det.id);
94 }
95 }
96
97 Err(DetectTypeError {
98 text: text.to_string(),
99 }
100 .into())
101 }
102
103 pub fn default_mermaid_11_12_2_full() -> Self {
104 let mut reg = Self::new();
105
106 reg.add_fn("error", detector_error);
108 reg.add_fn("---", detector_frontmatter_unparsed);
109
110 reg.add_fn("flowchart-elk", detector_flowchart_elk);
112 reg.add_fn("mindmap", detector_mindmap);
113 reg.add_fn("architecture", detector_architecture);
114 reg.add_fn("zenuml", detector_zenuml);
115
116 reg.add_fn("c4", detector_c4);
118 reg.add_fn("kanban", detector_kanban);
119 reg.add_fn("classDiagram", detector_class_v2);
120 reg.add_fn("class", detector_class_dagre_d3);
121 reg.add_fn("er", detector_er);
122 reg.add_fn("gantt", detector_gantt);
123 reg.add_fn("info", detector_info);
124 reg.add_fn("pie", detector_pie);
125 reg.add_fn("requirement", detector_requirement);
126 reg.add_fn("sequence", detector_sequence);
127 reg.add_fn("flowchart-v2", detector_flowchart_v2);
128 reg.add_fn("flowchart", detector_flowchart_dagre_d3_graph);
129 reg.add_fn("timeline", detector_timeline);
130 reg.add_fn("gitGraph", detector_git_graph);
131 reg.add_fn("stateDiagram", detector_state_v2);
132 reg.add_fn("state", detector_state_dagre_d3);
133 reg.add_fn("journey", detector_journey);
134 reg.add_fn("quadrantChart", detector_quadrant);
135 reg.add_fn("sankey", detector_sankey);
136 reg.add_fn("packet", detector_packet);
137 reg.add_fn("xychart", detector_xychart);
138 reg.add_fn("block", detector_block);
139 reg.add_fn("radar", detector_radar);
140 reg.add_fn("treemap", detector_treemap);
141
142 reg
143 }
144
145 pub fn default_mermaid_11_12_2_tiny() -> Self {
146 let mut reg = Self::new();
147
148 reg.add_fn("error", detector_error);
150 reg.add_fn("---", detector_frontmatter_unparsed);
151
152 reg.add_fn("zenuml", detector_zenuml);
154 reg.add_fn("c4", detector_c4);
155 reg.add_fn("kanban", detector_kanban);
156 reg.add_fn("classDiagram", detector_class_v2);
157 reg.add_fn("class", detector_class_dagre_d3);
158 reg.add_fn("er", detector_er);
159 reg.add_fn("gantt", detector_gantt);
160 reg.add_fn("info", detector_info);
161 reg.add_fn("pie", detector_pie);
162 reg.add_fn("requirement", detector_requirement);
163 reg.add_fn("sequence", detector_sequence);
164 reg.add_fn("flowchart-v2", detector_flowchart_v2);
165 reg.add_fn("flowchart", detector_flowchart_dagre_d3_graph);
166 reg.add_fn("timeline", detector_timeline);
167 reg.add_fn("gitGraph", detector_git_graph);
168 reg.add_fn("stateDiagram", detector_state_v2);
169 reg.add_fn("state", detector_state_dagre_d3);
170 reg.add_fn("journey", detector_journey);
171 reg.add_fn("quadrantChart", detector_quadrant);
172 reg.add_fn("sankey", detector_sankey);
173 reg.add_fn("packet", detector_packet);
174 reg.add_fn("xychart", detector_xychart);
175 reg.add_fn("block", detector_block);
176 reg.add_fn("radar", detector_radar);
177 reg.add_fn("treemap", detector_treemap);
178
179 reg
180 }
181
182 #[cfg(feature = "large-features")]
183 pub fn default_mermaid_11_12_2() -> Self {
184 Self::default_mermaid_11_12_2_full()
185 }
186
187 #[cfg(not(feature = "large-features"))]
188 pub fn default_mermaid_11_12_2() -> Self {
189 Self::default_mermaid_11_12_2_tiny()
190 }
191}
192
/// Cheap pre-check that maps a leading diagram keyword directly to a detector id.
///
/// Leading whitespace is skipped. The keywords are tried in table order, and the first one
/// that prefixes the text decides the result: its id is returned when the keyword is followed
/// by whitespace, `;`, or end of input, and `None` otherwise (e.g. `stateDiagram-v2`). `None`
/// is also returned when no keyword matches, so the caller can fall back to a full detector
/// scan.
fn fast_detect_by_leading_keyword(text: &str) -> Option<&'static str> {
    // (keyword, detector id) pairs, checked top to bottom.
    const KEYWORDS: [(&str, &str); 13] = [
        ("sequenceDiagram", "sequence"),
        ("classDiagram", "classDiagram"),
        ("stateDiagram", "stateDiagram"),
        ("mindmap", "mindmap"),
        ("architecture", "architecture"),
        ("erDiagram", "er"),
        ("gantt", "gantt"),
        ("timeline", "timeline"),
        ("journey", "journey"),
        ("gitGraph", "gitGraph"),
        ("quadrantChart", "quadrantChart"),
        ("packet-beta", "packet"),
        ("xychart-beta", "xychart"),
    ];

    let head = text.trim_start();
    for (keyword, id) in KEYWORDS {
        let Some(tail) = head.strip_prefix(keyword) else {
            continue;
        };
        // The first keyword that prefixes the text decides the outcome, even when the
        // word-boundary check fails.
        let at_boundary = match tail.chars().next() {
            None => true,
            Some(c) => c.is_whitespace() || c == ';',
        };
        return at_boundary.then_some(id);
    }
    None
}
248
/// Removes every `%%{ … }%%` directive from `text`.
///
/// Returns the input borrowed when it contains no `%%{`. An unterminated directive (an
/// opening `%%{` with no later `}%%`) causes everything from that `%%{` onward to be dropped.
fn remove_directives(text: &str) -> Cow<'_, str> {
    const OPEN: &str = "%%{";
    const CLOSE: &str = "}%%";

    if !text.contains(OPEN) {
        return Cow::Borrowed(text);
    }

    let mut kept = String::with_capacity(text.len());
    let mut remaining = text;
    loop {
        let Some((before, after_open)) = remaining.split_once(OPEN) else {
            kept.push_str(remaining);
            break;
        };
        kept.push_str(before);
        match after_open.split_once(CLOSE) {
            Some((_directive, after_close)) => remaining = after_close,
            // No closing marker: discard the rest of the text.
            None => break,
        }
    }
    Cow::Owned(kept)
}
270
271cached_regex!(
272 re_c4,
273 r"^\s*C4Context|C4Container|C4Component|C4Dynamic|C4Deployment"
274);
275
276impl Default for DetectorRegistry {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
282fn detector_frontmatter_unparsed(txt: &str, _config: &mut MermaidConfig) -> bool {
283 txt.trim_start().starts_with("---")
284}
285
286fn detector_error(txt: &str, _config: &mut MermaidConfig) -> bool {
287 txt.trim().eq_ignore_ascii_case("error")
288}
289
290fn detector_c4(txt: &str, _config: &mut MermaidConfig) -> bool {
291 re_c4().is_match(txt)
293}
294
295fn detector_kanban(txt: &str, _config: &mut MermaidConfig) -> bool {
296 txt.trim_start().starts_with("kanban")
297}
298
299fn detector_class_dagre_d3(txt: &str, config: &mut MermaidConfig) -> bool {
300 if config.get_str("class.defaultRenderer") == Some("dagre-wrapper") {
301 return false;
302 }
303 txt.trim_start().starts_with("classDiagram")
304}
305
306fn detector_class_v2(txt: &str, config: &mut MermaidConfig) -> bool {
307 if txt.trim_start().starts_with("classDiagram")
308 && config.get_str("class.defaultRenderer") == Some("dagre-wrapper")
309 {
310 return true;
311 }
312 txt.trim_start().starts_with("classDiagram-v2")
313}
314
315fn detector_er(txt: &str, _config: &mut MermaidConfig) -> bool {
316 txt.trim_start().starts_with("erDiagram")
317}
318
319fn detector_gantt(txt: &str, _config: &mut MermaidConfig) -> bool {
320 txt.trim_start().starts_with("gantt")
321}
322
323fn detector_info(txt: &str, _config: &mut MermaidConfig) -> bool {
324 txt.trim_start().starts_with("info")
325}
326
327fn detector_pie(txt: &str, _config: &mut MermaidConfig) -> bool {
328 txt.trim_start().starts_with("pie")
329}
330
331fn detector_requirement(txt: &str, _config: &mut MermaidConfig) -> bool {
332 txt.trim_start().starts_with("requirement")
333}
334
335fn detector_sequence(txt: &str, _config: &mut MermaidConfig) -> bool {
336 txt.trim_start().starts_with("sequenceDiagram")
337}
338
339fn detector_flowchart_elk(txt: &str, config: &mut MermaidConfig) -> bool {
340 let trimmed = txt.trim_start();
341 if trimmed.starts_with("flowchart-elk")
342 || ((trimmed.starts_with("flowchart") || trimmed.starts_with("graph"))
343 && config.get_str("flowchart.defaultRenderer") == Some("elk"))
344 {
345 config.set_value("layout", serde_json::Value::String("elk".to_string()));
346 return true;
347 }
348 false
349}
350
351fn detector_flowchart_v2(txt: &str, config: &mut MermaidConfig) -> bool {
352 if config.get_str("flowchart.defaultRenderer") == Some("dagre-d3") {
353 return false;
354 }
355 if config.get_str("flowchart.defaultRenderer") == Some("elk") {
356 config.set_value("layout", serde_json::Value::String("elk".to_string()));
357 }
358
359 if txt.trim_start().starts_with("graph")
360 && config.get_str("flowchart.defaultRenderer") == Some("dagre-wrapper")
361 {
362 return true;
363 }
364 txt.trim_start().starts_with("flowchart")
365}
366
367fn detector_flowchart_dagre_d3_graph(txt: &str, config: &mut MermaidConfig) -> bool {
368 if matches!(
369 config.get_str("flowchart.defaultRenderer"),
370 Some("dagre-wrapper" | "elk")
371 ) {
372 return false;
373 }
374 txt.trim_start().starts_with("graph")
375}
376
377fn detector_timeline(txt: &str, _config: &mut MermaidConfig) -> bool {
378 txt.trim_start().starts_with("timeline")
379}
380
381fn detector_git_graph(txt: &str, _config: &mut MermaidConfig) -> bool {
382 txt.trim_start().starts_with("gitGraph")
383}
384
385fn detector_state_dagre_d3(txt: &str, config: &mut MermaidConfig) -> bool {
386 if config.get_str("state.defaultRenderer") == Some("dagre-wrapper") {
387 return false;
388 }
389 txt.trim_start().starts_with("stateDiagram")
390}
391
392fn detector_state_v2(txt: &str, config: &mut MermaidConfig) -> bool {
393 let trimmed = txt.trim_start();
394 if trimmed.starts_with("stateDiagram-v2") {
395 return true;
396 }
397 trimmed.starts_with("stateDiagram")
398 && config.get_str("state.defaultRenderer") == Some("dagre-wrapper")
399}
400
401fn detector_journey(txt: &str, _config: &mut MermaidConfig) -> bool {
402 txt.trim_start().starts_with("journey")
403}
404
405fn detector_quadrant(txt: &str, _config: &mut MermaidConfig) -> bool {
406 txt.trim_start().starts_with("quadrantChart")
407}
408
409fn detector_sankey(txt: &str, _config: &mut MermaidConfig) -> bool {
410 txt.trim_start().starts_with("sankey")
411}
412
413fn detector_packet(txt: &str, _config: &mut MermaidConfig) -> bool {
414 txt.trim_start().starts_with("packet")
415}
416
417fn detector_xychart(txt: &str, _config: &mut MermaidConfig) -> bool {
418 txt.trim_start().starts_with("xychart")
419}
420
421fn detector_block(txt: &str, _config: &mut MermaidConfig) -> bool {
422 txt.trim_start().starts_with("block")
423}
424
425fn detector_radar(txt: &str, _config: &mut MermaidConfig) -> bool {
426 txt.trim_start().starts_with("radar-beta")
427}
428
429fn detector_treemap(txt: &str, _config: &mut MermaidConfig) -> bool {
430 txt.trim_start().starts_with("treemap")
431}
432
433fn detector_mindmap(txt: &str, _config: &mut MermaidConfig) -> bool {
434 txt.trim_start().starts_with("mindmap")
435}
436
437fn detector_architecture(txt: &str, _config: &mut MermaidConfig) -> bool {
438 txt.trim_start().starts_with("architecture")
439}
440
441fn detector_zenuml(txt: &str, _config: &mut MermaidConfig) -> bool {
442 txt.trim_start().starts_with("zenuml")
443}
444
#[cfg(test)]
mod remove_directives_tests {
    use super::remove_directives;
    use std::borrow::Cow;

    #[test]
    fn no_directives_is_borrowed() {
        let s = "flowchart TD; A-->B;";
        assert!(matches!(remove_directives(s), Cow::Borrowed(_)));
    }

    #[test]
    fn removes_directive_block() {
        let s = "%%{init: {\"theme\": \"dark\"}}%%\nflowchart TD; A-->B;";
        let out = remove_directives(s);
        assert!(out.as_ref().contains("flowchart TD"));
        assert!(!out.as_ref().contains("init"));
    }

    // Exercises the loop path: every directive is removed and the text around each is kept.
    #[test]
    fn removes_every_directive_and_keeps_surrounding_text() {
        let s = "%%{init: {}}%%\nflowchart TD\n%%{wrap}%%\nA-->B;";
        assert_eq!(remove_directives(s), "\nflowchart TD\n\nA-->B;");
    }

    #[test]
    fn empty_directive_is_removed() {
        assert_eq!(remove_directives("%%{}%%pie"), "pie");
    }

    #[test]
    fn unterminated_directive_truncates_at_start() {
        let s = "flowchart\n%%{init: {\"theme\": \"dark\"}}\nA-->B;";
        let out = remove_directives(s);
        assert_eq!(out.as_ref().trim(), "flowchart");
    }
}
470}