/// Rewrites every Mermaid code block in rendered HTML into a
/// `<div class="mermaid" data-source="…">…</div>` wrapper.
///
/// A block is recognised as a `<pre …>` element (any attributes) whose content
/// opens — after optional whitespace — with `<code class="language-mermaid">` and
/// is terminated by `</code></pre>`. The code content is copied verbatim into the
/// div body, and the same content, passed through [`html_escape`], is stored in
/// the `data-source` attribute.
///
/// Everything else is copied through unchanged. This includes `<pre>` blocks
/// whose text merely contains the string `language-mermaid`. If a `<pre` tag is
/// never closed with `>`, or a Mermaid block has no `</code></pre>`, the
/// remainder of the input from that point on is copied unchanged.
pub fn inject_mermaid_wrappers(html: &str) -> String {
    const CODE_OPEN: &str = r#"<code class="language-mermaid">"#;
    const CLOSE_PATTERN: &str = "</code></pre>";

    let mut output = String::with_capacity(html.len());
    let mut rest = html;
    while let Some(pre_start) = find_mermaid_open(rest) {
        // Byte offset just past the `>` that ends the opening `<pre …>` tag.
        let pre_tag_end = match rest[pre_start..].find('>') {
            Some(offset) => pre_start + offset + 1,
            None => break,
        };

        // The Mermaid code element must open directly after the `<pre>` tag.
        // Searching further ahead would latch onto a *later* block and drop
        // everything in between, so on a mismatch we keep this `<pre …>` tag
        // verbatim and resume scanning after it.
        let after_pre = &rest[pre_tag_end..];
        let after_ws = after_pre.trim_start();
        if !after_ws.starts_with(CODE_OPEN) {
            output.push_str(&rest[..pre_tag_end]);
            rest = after_pre;
            continue;
        }

        let content_start = pre_tag_end + (after_pre.len() - after_ws.len()) + CODE_OPEN.len();
        let content_end = match rest[content_start..].find(CLOSE_PATTERN) {
            Some(offset) => content_start + offset,
            // Unterminated block: leave it and the rest of the input untouched.
            None => break,
        };

        let source = &rest[content_start..content_end];
        let escaped = html_escape(source);
        output.push_str(&rest[..pre_start]);
        output.push_str(&format!(
            r#"<div class="mermaid" data-source="{escaped}">{source}</div>"#
        ));
        rest = &rest[content_end + CLOSE_PATTERN.len()..];
    }
    output.push_str(rest);
    output
}
/// Finds the `<pre` tag that encloses an occurrence of `language-mermaid`.
///
/// Occurrences of the substring `language-mermaid` are examined left to right.
/// For each one, the nearest preceding `<pre` is located. If no `</pre>` lies
/// between that tag and the occurrence, the tag's byte offset is returned.
/// Occurrences with no open `<pre` before them are skipped.
/// Returns `None` when no occurrence qualifies.
fn find_mermaid_open(html: &str) -> Option<usize> {
    const NEEDLE: &str = "language-mermaid";
    html.match_indices(NEEDLE).find_map(|(hit, _)| {
        let pre_at = html[..hit].rfind("<pre")?;
        let still_open = !html[pre_at..hit].contains("</pre>");
        if still_open {
            Some(pre_at)
        } else {
            None
        }
    })
}
/// Escapes `s` so it can be embedded in a double-quoted HTML attribute value
/// or in element text.
///
/// `&`, `"`, `<` and `>` become `&amp;`, `&quot;`, `&lt;` and `&gt;`.
/// All other characters are copied unchanged. The input is processed in one
/// pass, so entities produced for one character are never re-escaped. Existing
/// entities in the input *are* escaped again (`&gt;` becomes `&amp;gt;`).
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '"' => escaped.push_str("&quot;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            other => escaped.push(other),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Marker attribute that every generated Mermaid wrapper carries.
    const WRAPPER_MARKER: &str = r#"class="mermaid""#;

    /// A plain `<pre><code>` Mermaid block is replaced by a wrapper that keeps
    /// the diagram text and drops the original `<pre>`/`<code>` markup.
    #[test]
    fn test_mermaid_wrapper() {
        let rendered = inject_mermaid_wrappers(
            r#"<pre><code class="language-mermaid">graph TD; A-->B;</code></pre>"#,
        );
        assert!(rendered.contains(WRAPPER_MARKER));
        assert!(rendered.contains("graph TD; A-->B;"));
        for leftover in ["<pre>", "<code"] {
            assert!(!rendered.contains(leftover), "unexpected {leftover} in {rendered}");
        }
    }

    /// A `<pre>` tag carrying its own attributes is still recognised.
    #[test]
    fn test_mermaid_wrapper_highlighted_pre() {
        let rendered = inject_mermaid_wrappers(
            r#"<pre class="code-block"><code class="language-mermaid">graph LR; X-->Y;</code></pre>"#,
        );
        assert!(rendered.contains(WRAPPER_MARKER));
        assert!(rendered.contains("graph LR;"));
    }

    /// Non-Mermaid code blocks pass through byte-for-byte.
    #[test]
    fn test_no_mermaid_passthrough() {
        let original = r#"<pre><code class="language-rust">fn main() {}</code></pre>"#;
        assert_eq!(inject_mermaid_wrappers(original), original);
    }

    /// Each Mermaid block in a document gets its own wrapper.
    #[test]
    fn test_multiple_mermaid_blocks() {
        let document = [
            r#"<pre><code class="language-mermaid">A-->B</code></pre>"#,
            "<p>text</p>",
            r#"<pre><code class="language-mermaid">C-->D</code></pre>"#,
        ]
        .concat();
        let wrappers = inject_mermaid_wrappers(&document)
            .matches(WRAPPER_MARKER)
            .count();
        assert_eq!(wrappers, 2);
    }
}