merman_render/svg/pipeline/builtin/
scoped_css.rs1use crate::Result;
2use std::borrow::Cow;
3
4use super::css_override::{CssOverridePolicy, strip_css_important};
5use super::util::{escape_xml_attr, find_matching_brace, find_tag_end};
6use crate::svg::pipeline::{SvgPostprocessContext, SvgPostprocessor};
7
8#[derive(Debug, Clone)]
9pub struct ScopedCssPostprocessor {
10 css: String,
11 override_policy: CssOverridePolicy,
12}
13
14impl ScopedCssPostprocessor {
15 pub fn new(css: impl Into<String>) -> Self {
16 Self {
17 css: css.into(),
18 override_policy: CssOverridePolicy::Preserve,
19 }
20 }
21
22 pub fn with_override_policy(mut self, policy: CssOverridePolicy) -> Self {
23 self.override_policy = policy;
24 self
25 }
26
27 pub fn css(&self) -> &str {
28 &self.css
29 }
30
31 pub fn override_policy(&self) -> CssOverridePolicy {
32 self.override_policy
33 }
34}
35
36impl SvgPostprocessor for ScopedCssPostprocessor {
37 fn name(&self) -> &'static str {
38 "scoped-css"
39 }
40
41 fn process<'a>(
42 &self,
43 svg: Cow<'a, str>,
44 ctx: &SvgPostprocessContext<'_>,
45 ) -> Result<Cow<'a, str>> {
46 if self.css.trim().is_empty() {
47 return Ok(svg);
48 }
49
50 let mut base = match self.override_policy {
51 CssOverridePolicy::Preserve => svg.into_owned(),
52 CssOverridePolicy::StripExistingImportant => strip_css_important(svg.as_ref()),
53 };
54 let scoped_css = scope_css(&self.css, ctx.svg_id());
55 inject_style(&mut base, &scoped_css);
56 Ok(Cow::Owned(base))
57 }
58}
59
60fn inject_style(svg: &mut String, css: &str) {
61 let css = css.replace("</style", "<\\/style");
62 let style = format!(
63 r#"<style data-merman-postprocess="scoped-css">{}</style>"#,
64 css
65 );
66
67 if let Some(start) = svg.find("<svg") {
68 if let Some(end) = find_tag_end(svg, start) {
69 if let Some(style_close_start) = svg.rfind("</style") {
70 if let Some(style_close_end) = find_tag_end(svg, style_close_start) {
71 svg.insert_str(style_close_end + 1, &style);
72 return;
73 }
74 }
75 svg.insert_str(end + 1, &style);
76 return;
77 }
78 }
79
80 svg.push_str(&style);
81}
82
83fn scope_css(css: &str, svg_id: Option<&str>) -> String {
84 let Some(svg_id) = svg_id.filter(|id| !id.trim().is_empty()) else {
85 return css.to_string();
86 };
87 let scope = format!("#{}", css_escape_id(svg_id));
88 let mut out = String::with_capacity(css.len() + scope.len() * 4);
89 let mut cursor = 0;
90
91 while let Some(rel_open) = css[cursor..].find('{') {
92 let open = cursor + rel_open;
93 let selector = &css[cursor..open];
94 let Some(close) = find_matching_brace(css, open) else {
95 out.push_str(&css[cursor..]);
96 return out;
97 };
98
99 if selector.trim_start().starts_with('@') {
100 out.push_str(&css[cursor..=close]);
101 } else {
102 out.push_str(&scope_selector(selector, &scope));
103 out.push(' ');
104 out.push_str(&css[open..=close]);
105 }
106 cursor = close + 1;
107 }
108
109 out.push_str(&css[cursor..]);
110 out
111}
112
/// Prefixes each selector of a comma-separated selector list with `scope`.
///
/// * `selector` - the rule prelude (may contain several comma-separated selectors).
/// * `scope` - the scoping selector to prepend, e.g. `#diagram`.
///
/// Per selector:
/// * empty entries stay empty,
/// * selectors that already begin with `scope` (as a complete ID, not merely as a textual
///   prefix of a longer ID such as `#diagram2`) are kept as written,
/// * a bare `:root` or `svg` is replaced by `scope` itself,
/// * everything else becomes `"{scope} {selector}"`.
///
/// Only top-level commas separate selectors; commas inside `(...)`, `[...]` or quoted strings
/// (for example `:is(.a, .b)` or `[data-x="a,b"]`) are left alone.
fn scope_selector(selector: &str, scope: &str) -> String {
    split_top_level_commas(selector)
        .into_iter()
        .map(|part| {
            let trimmed = part.trim();
            if trimmed.is_empty() {
                String::new()
            } else if is_already_scoped(trimmed, scope) {
                trimmed.to_string()
            } else if trimmed == ":root" || trimmed == "svg" {
                scope.to_string()
            } else {
                format!("{scope} {trimmed}")
            }
        })
        .collect::<Vec<_>>()
        .join(", ")
}

/// Splits `selector` at commas that are outside parentheses, brackets and quoted strings.
fn split_top_level_commas(selector: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut start = 0;

    for (idx, ch) in selector.char_indices() {
        if escaped {
            // The character after a backslash is literal, whatever it is.
            escaped = false;
            continue;
        }
        match (quote, ch) {
            (_, '\\') => escaped = true,
            (Some(opening), c) if c == opening => quote = None,
            (Some(_), _) => {}
            (None, '"' | '\'') => quote = Some(ch),
            (None, '(' | '[') => depth += 1,
            (None, ')' | ']') => depth = depth.saturating_sub(1),
            (None, ',') if depth == 0 => {
                parts.push(&selector[start..idx]);
                start = idx + 1;
            }
            (None, _) => {}
        }
    }
    parts.push(&selector[start..]);
    parts
}

/// Reports whether `selector` starts with `scope` followed by an identifier boundary.
///
/// A following identifier character (or a backslash escape) would mean the selector names a
/// different, longer ID that merely shares the prefix.
fn is_already_scoped(selector: &str, scope: &str) -> bool {
    let Some(rest) = selector.strip_prefix(scope) else {
        return false;
    };
    match rest.chars().next() {
        None => true,
        Some(next) => {
            !(next.is_ascii_alphanumeric() || matches!(next, '-' | '_' | '\\') || !next.is_ascii())
        }
    }
}
131
/// Escapes `id` so it can be written after `#` as a CSS ID selector.
///
/// Follows the CSSOM "serialize an identifier" rules (the algorithm behind `CSS.escape`):
/// * NUL becomes U+FFFD,
/// * ASCII control characters are written as a hex escape (`\a `), since a backslash followed
///   by a newline is not a valid CSS escape,
/// * a digit in first position, or in second position after a leading `-`, is written as a
///   hex escape (`\31 `) because an identifier may not start that way,
/// * a lone `-` becomes `\-`,
/// * ASCII letters, digits, `-`, `_` and all non-ASCII characters are kept as-is,
/// * every other character is preceded by a backslash.
fn css_escape_id(id: &str) -> String {
    use std::fmt::Write as _;

    if id == "-" {
        return String::from("\\-");
    }

    let leading_hyphen = id.starts_with('-');
    let mut out = String::with_capacity(id.len() + 4);

    for (index, ch) in id.chars().enumerate() {
        let digit_needs_escape =
            ch.is_ascii_digit() && (index == 0 || (index == 1 && leading_hyphen));

        match ch {
            '\0' => out.push('\u{FFFD}'),
            _ if ch.is_ascii_control() || digit_needs_escape => {
                // Hex escapes are terminated by a single space so following hex-like
                // characters are not swallowed. Writing to a `String` cannot fail.
                let _ = write!(out, "\\{:x} ", u32::from(ch));
            }
            _ if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || !ch.is_ascii() => {
                out.push(ch);
            }
            _ => {
                out.push('\\');
                out.push(ch);
            }
        }
    }
    out
}
145
146#[allow(dead_code)]
147fn scoped_attr_selector(id: &str) -> String {
148 format!(r#"svg[id="{}"]"#, escape_xml_attr(id))
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::svg::pipeline::SvgPipeline;

    /// Without any `<style>` in the document the injected element directly follows the root
    /// tag, and every selector of the list is prefixed with the diagram ID.
    #[test]
    fn scoped_css_injects_after_root_svg_tag_when_no_style_exists() {
        let postprocessor = ScopedCssPostprocessor::new(".node rect, text.label { fill: red; }");
        let out = SvgPipeline::parity()
            .with_postprocessor(postprocessor)
            .process_to_string(r#"<svg id="diagram"><rect class="node"/></svg>"#)
            .unwrap();

        assert!(out.starts_with(r#"<svg id="diagram"><style"#));
        assert!(out.contains("#diagram .node rect, #diagram text.label { fill: red; }"));
    }

    /// Injected rules must come after the stylesheet already present in the SVG so they win
    /// on equal specificity.
    #[test]
    fn scoped_css_injects_after_existing_style_for_cascade_order() {
        let postprocessor = ScopedCssPostprocessor::new(".node rect { fill: green; }");
        let out = SvgPipeline::parity()
            .with_postprocessor(postprocessor)
            .process_to_string(
                r#"<svg id="diagram"><style>#diagram .node rect { fill: red; }</style><g/></svg>"#,
            )
            .unwrap();

        let position_of = |needle: &str| out.find(needle).unwrap();
        let existing = position_of("fill: red");
        let injected = position_of("fill: green");
        assert!(
            existing < injected,
            "injected CSS should follow Mermaid CSS for cascade order: {out}"
        );
    }

    /// With `StripExistingImportant`, `!important` is removed from the existing stylesheet
    /// before the scoped rules are added.
    #[test]
    fn scoped_css_can_strip_existing_important_before_injection() {
        let postprocessor = ScopedCssPostprocessor::new(".node { fill: green; }")
            .with_override_policy(CssOverridePolicy::StripExistingImportant);
        let out = SvgPipeline::parity()
            .with_postprocessor(postprocessor)
            .process_to_string(
                r#"<svg id="diagram"><style>.node{fill:red !important;}</style></svg>"#,
            )
            .unwrap();

        assert!(!out.contains("!important"));
        assert!(out.contains("#diagram .node { fill: green; }"));
    }
}
201}