use comrak::nodes::{AstNode, NodeHtmlBlock, NodeValue};
/// Rewrites fenced Mermaid code blocks into raw HTML diagram containers.
///
/// Walks every descendant of `root`. Some of those nodes are `CodeBlock`s whose
/// info string's first whitespace-separated word is `mermaid`, compared
/// ASCII case-insensitively. Each of those is replaced in place by an
/// `HtmlBlock` whose literal is `<pre class="mermaid">…</pre>\n`, with the
/// diagram source HTML-escaped. All other nodes are left untouched.
///
/// NOTE(review): comrak's HTML renderer writes raw HTML blocks verbatim only
/// when unsafe rendering is enabled in its render options. Otherwise the
/// container is replaced by a placeholder comment. Confirm that callers enable
/// unsafe rendering.
pub fn process_diagram_code_blocks<'a>(root: &'a AstNode<'a>) {
    // CommonMark HTML-block start condition 1 covers lines beginning with
    // `<pre`. The generated container starts with `<pre`, so tag it as type 1
    // rather than type 6, which is reserved for the other block-level tags.
    const PRE_HTML_BLOCK_TYPE: u8 = 1;

    /// Returns whether a fenced code block's info string names Mermaid.
    fn is_mermaid_info(info: &str) -> bool {
        // Only the first word is the language; the rest are attributes.
        // Compare in place to avoid allocating a lowercased copy per block.
        matches!(
            info.split_whitespace().next(),
            Some(lang) if lang.eq_ignore_ascii_case("mermaid")
        )
    }

    for node in root.descendants() {
        let mut ast = node.data.borrow_mut();
        let replacement = match ast.value {
            NodeValue::CodeBlock(ref block) if is_mermaid_info(&block.info) => {
                Some(render_mermaid(&block.literal))
            }
            _ => None,
        };
        if let Some(literal) = replacement {
            ast.value = NodeValue::HtmlBlock(NodeHtmlBlock {
                block_type: PRE_HTML_BLOCK_TYPE,
                literal,
            });
        }
    }
}
/// Wraps Mermaid diagram source in a `<pre class="mermaid">` container.
///
/// The source is passed through `html_escape::encode_text` before being placed
/// between the opening and closing tags. This means the diagram text is
/// treated as element text content rather than as markup. The result always
/// ends with a newline after `</pre>`.
fn render_mermaid(source: &str) -> String {
    const OPEN_TAG: &str = "<pre class=\"mermaid\">";
    const CLOSE_TAG: &str = "</pre>\n";

    let body = html_escape::encode_text(source);
    [OPEN_TAG, &*body, CLOSE_TAG].concat()
}
/// Contents of `diagrams_hydrator.js`, embedded into the binary at compile
/// time via `include_str!`.
const HYDRATION_SCRIPT: &str = include_str!("diagrams_hydrator.js");

/// Returns the embedded diagram hydration snippet.
///
/// The string is the verbatim contents of `diagrams_hydrator.js`. The unit
/// tests expect it to be a complete `<script type="module">…</script>` element
/// that references `pre.mermaid`.
#[must_use]
pub fn hydration_script_html() -> &'static str {
    HYDRATION_SCRIPT
}
#[cfg(test)]
mod tests {
    use super::*;
    use comrak::{parse_document, Arena, Options};

    /// Returns the literal of the first `HtmlBlock` under `root` that contains `needle`.
    fn find_html_containing<'a>(root: &'a AstNode<'a>, needle: &str) -> Option<String> {
        for node in root.descendants() {
            if let NodeValue::HtmlBlock(ref block) = node.data.borrow().value {
                if block.literal.contains(needle) {
                    return Some(block.literal.clone());
                }
            }
        }
        None
    }

    /// Parses `source` with default options, runs the diagram pass, and searches the result.
    fn transform_and_find(source: &str, needle: &str) -> Option<String> {
        let arena = Arena::new();
        let root = parse_document(&arena, source, &Options::default());
        process_diagram_code_blocks(root);
        find_html_containing(root, needle)
    }

    #[test]
    fn test_mermaid_block_rewritten() {
        let md = "```mermaid\ngraph TD\n A --> B\n```\n";
        let found = transform_and_find(md, "class=\"mermaid\"")
            .expect("mermaid container missing");
        assert!(found.starts_with("<pre class=\"mermaid\">"));
        assert!(found.contains("graph TD"));
        // `>` is HTML-escaped inside the container.
        assert!(found.contains("A --&gt; B"), "content escaped");
    }

    #[test]
    fn test_info_string_with_attributes() {
        let md = "```mermaid classDiagram\nclassDiagram\n A<|--B\n```\n";
        assert!(transform_and_find(md, "class=\"mermaid\"").is_some());
    }

    #[test]
    fn test_info_string_is_case_insensitive() {
        let md = "```Mermaid\ngraph TD\n A --> B\n```\n";
        assert!(transform_and_find(md, "class=\"mermaid\"").is_some());
    }

    #[test]
    fn test_empty_info_string_passes_through() {
        let md = "```\ngraph TD\n```\n";
        assert!(transform_and_find(md, "class=\"mermaid\"").is_none());
    }

    #[test]
    fn test_unknown_lang_passes_through() {
        let md = "```rust\nfn main() {}\n```\n";
        assert!(transform_and_find(md, "class=\"mermaid\"").is_none());
    }

    #[test]
    fn test_non_matching_diagram_langs_pass_through() {
        for lang in ["geojson", "topojson", "stl"] {
            let md = format!("```{lang}\n{{\"a\":1}}\n```\n");
            assert!(
                transform_and_find(&md, "class=\"mermaid\"").is_none(),
                "{lang} should not produce a mermaid container"
            );
        }
    }

    #[test]
    fn test_content_is_html_escaped() {
        let md = "```mermaid\ngraph <script>alert(1)</script>\n```\n";
        let found = transform_and_find(md, "class=\"mermaid\"").unwrap();
        // The raw tag must not survive; its escaped form must.
        assert!(!found.contains("<script>"));
        assert!(found.contains("&lt;script&gt;alert(1)&lt;/script&gt;"));
    }

    #[test]
    fn test_hydration_script_imports_mermaid() {
        let s = hydration_script_html();
        assert!(s.contains("pre.mermaid"));
        assert!(s.contains("mermaid"));
        assert!(s.starts_with("<script type=\"module\">"));
        assert!(s.trim_end().ends_with("</script>"));
    }
}