Skip to main content

merman_render/svg/pipeline/
mod.rs

1mod builtin;
2mod context;
3mod preset;
4
5pub use builtin::{
6    CssOverridePolicy, CssOverridePostprocessor, ForeignObjectFallbackPostprocessor,
7    SanitizeCssPostprocessor, SanitizeSvgAttributesPostprocessor, ScopedCssPostprocessor,
8    StripForeignObjectPostprocessor,
9};
10pub use context::{SvgPostprocessContext, SvgPostprocessMetadata};
11pub use preset::{SvgPipelinePreset, resvg_safe_svg};
12
13use crate::{Error, Result};
14use std::borrow::Cow;
15use std::fmt;
16use std::sync::Arc;
17
18pub trait SvgPostprocessor: Send + Sync {
19    fn name(&self) -> &'static str;
20
21    fn process<'a>(
22        &self,
23        svg: Cow<'a, str>,
24        ctx: &SvgPostprocessContext<'_>,
25    ) -> Result<Cow<'a, str>>;
26}
27
28#[derive(Clone)]
29pub struct SvgPipeline {
30    preset: SvgPipelinePreset,
31    postprocessors: Vec<Arc<dyn SvgPostprocessor>>,
32}
33
34impl fmt::Debug for SvgPipeline {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        let names = self
37            .postprocessors
38            .iter()
39            .map(|pass| pass.name())
40            .collect::<Vec<_>>();
41
42        f.debug_struct("SvgPipeline")
43            .field("preset", &self.preset)
44            .field("postprocessors", &names)
45            .finish()
46    }
47}
48
49impl Default for SvgPipeline {
50    fn default() -> Self {
51        Self::parity()
52    }
53}
54
55impl SvgPipeline {
56    pub fn parity() -> Self {
57        Self::from_preset(SvgPipelinePreset::Parity)
58    }
59
60    pub fn readable() -> Self {
61        Self::from_preset(SvgPipelinePreset::Readable)
62    }
63
64    pub fn resvg_safe() -> Self {
65        Self::from_preset(SvgPipelinePreset::ResvgSafe)
66    }
67
68    pub fn from_preset(preset: SvgPipelinePreset) -> Self {
69        Self {
70            preset,
71            postprocessors: Vec::new(),
72        }
73    }
74
75    pub fn preset(&self) -> SvgPipelinePreset {
76        self.preset
77    }
78
79    pub fn with_postprocessor<P>(mut self, postprocessor: P) -> Self
80    where
81        P: SvgPostprocessor + 'static,
82    {
83        self.postprocessors.push(Arc::new(postprocessor));
84        self
85    }
86
87    pub fn with_shared_postprocessor(mut self, postprocessor: Arc<dyn SvgPostprocessor>) -> Self {
88        self.postprocessors.push(postprocessor);
89        self
90    }
91
92    pub fn push_postprocessor<P>(&mut self, postprocessor: P)
93    where
94        P: SvgPostprocessor + 'static,
95    {
96        self.postprocessors.push(Arc::new(postprocessor));
97    }
98
99    pub fn process<'a>(&self, svg: &'a str) -> Result<Cow<'a, str>> {
100        let metadata = SvgPostprocessMetadata::from_svg(svg);
101        self.process_with_metadata(svg, &metadata)
102    }
103
104    pub fn process_with_metadata<'a>(
105        &self,
106        svg: &'a str,
107        metadata: &SvgPostprocessMetadata,
108    ) -> Result<Cow<'a, str>> {
109        let mut current = preset::apply_preset(self.preset, svg);
110
111        for (index, postprocessor) in self.postprocessors.iter().enumerate() {
112            let ctx =
113                SvgPostprocessContext::new(self.preset, index, postprocessor.name(), metadata);
114            current = postprocessor
115                .process(current, &ctx)
116                .map_err(|err| Error::svg_postprocess(postprocessor.name(), err.to_string()))?;
117        }
118
119        Ok(current)
120    }
121
122    pub fn process_to_string(&self, svg: &str) -> Result<String> {
123        Ok(self.process(svg)?.into_owned())
124    }
125
126    pub fn process_to_string_with_metadata(
127        &self,
128        svg: &str,
129        metadata: &SvgPostprocessMetadata,
130    ) -> Result<String> {
131        Ok(self.process_with_metadata(svg, metadata)?.into_owned())
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parity_pipeline_preserves_svg_exactly() {
        let svg = r#"<svg><style>@keyframes a{to{opacity:1}}</style><rect width="10"/></svg>"#;
        let out = SvgPipeline::parity().process(svg).unwrap();
        assert!(matches!(out, Cow::Borrowed(_)));
        assert_eq!(out, svg);
    }

    #[test]
    fn readable_pipeline_matches_foreign_object_fallback() {
        let svg = r#"<svg xmlns="http://www.w3.org/2000/svg"><g transform="translate(10,20)"><foreignObject width="80" height="48"><div xmlns="http://www.w3.org/1999/xhtml"><p>Layer 7\nHTTP</p></div></foreignObject></g></svg>"#;

        let expected = super::builtin::foreign_object::foreign_object_fallback_svg(svg);
        let out = SvgPipeline::readable().process_to_string(svg).unwrap();

        assert_eq!(out, expected);
        assert!(out.contains(">Layer 7</text>"));
        assert!(out.contains(">HTTP</text>"));
    }

    #[test]
    fn resvg_safe_pipeline_strips_generic_raster_hazards() {
        let svg = r#"<svg id="test" xmlns="http://www.w3.org/2000/svg"><style type="text/css">@keyframes bounce { 0% { transform: scale(1); } 100% { transform: scale(1.1); } } #test :root { --bg: white; } .node rect { animation: dash 1s linear; transform: rotate(45deg); fill: red; }</style><g transform="translate(undefined,NaN)"><foreignObject width="10" height="10"><div xmlns="http://www.w3.org/1999/xhtml"><p>Hello</p></div></foreignObject><rect width="10px" height="12px" stroke="" style="fill: ; stroke: #333; transform: rotate(45deg); animation: dash 1s;"/><rect width="10px" height="" fill="hsl(240, 100%, NaN%)"/></g></svg>"#;

        let out = SvgPipeline::resvg_safe().process_to_string(svg).unwrap();

        assert!(!out.contains("<foreignObject"));
        assert!(!out.contains("@keyframes"));
        assert!(!out.contains(":root"));
        assert!(!out.contains("animation"));
        assert!(!out.contains("deg"));
        assert!(!out.contains("NaN"));
        assert!(!out.contains("undefined"));
        assert!(!out.contains(r#"height="""#));
        assert!(!out.contains(r#"fill="hsl"#));
        assert!(!out.contains(r#"stroke="""#));
        assert!(out.contains(r#"width="10""#));
        assert!(out.contains(r#"height="12""#));
        assert!(out.contains("stroke:#333"));
        assert!(out.contains(">Hello</text>"));
    }

    /// Test pass that appends a comment recording everything its context exposes.
    struct AppendPass(&'static str);

    impl SvgPostprocessor for AppendPass {
        fn name(&self) -> &'static str {
            self.0
        }

        fn process<'a>(
            &self,
            svg: Cow<'a, str>,
            ctx: &SvgPostprocessContext<'_>,
        ) -> Result<Cow<'a, str>> {
            Ok(Cow::Owned(format!(
                "{}<!--{}:{}:{:?}:{}:{}:{}-->",
                svg,
                ctx.pass_index(),
                ctx.pass_name(),
                ctx.preset(),
                ctx.diagram_type().unwrap_or("none"),
                ctx.diagram_title().unwrap_or("none"),
                ctx.svg_id().unwrap_or("none")
            )))
        }
    }

    #[test]
    fn custom_postprocessors_run_after_builtin_preset_in_order() {
        let svg = r#"<svg><foreignObject width="10" height="10"><div><p>Hello</p></div></foreignObject></svg>"#;
        let pipeline = SvgPipeline::readable()
            .with_postprocessor(AppendPass("first"))
            .with_postprocessor(AppendPass("second"));

        let out = pipeline.process_to_string(svg).unwrap();

        let fallback = out.find("data-merman-foreignobject").unwrap();
        let first = out.find("<!--0:first:Readable").unwrap();
        let second = out.find("<!--1:second:Readable").unwrap();
        assert!(fallback < first);
        assert!(first < second);
    }

    #[test]
    fn custom_postprocessor_context_exposes_metadata() {
        let svg = r#"<svg id="host-diagram"><rect width="10"/></svg>"#;
        let metadata = SvgPostprocessMetadata::from_svg(svg)
            .with_diagram_type("flowchart-v2")
            .with_diagram_title("Host Diagram");
        let pipeline = SvgPipeline::parity().with_postprocessor(AppendPass("meta"));

        let out = pipeline
            .process_to_string_with_metadata(svg, &metadata)
            .unwrap();

        assert!(out.contains("<!--0:meta:Parity:flowchart-v2:Host Diagram:host-diagram-->"));
    }

    #[test]
    fn default_pipeline_is_parity_without_postprocessors() {
        let pipeline = SvgPipeline::default();
        assert!(matches!(pipeline.preset(), SvgPipelinePreset::Parity));

        let svg = "<svg/>";
        let out = pipeline.process(svg).unwrap();
        assert!(matches!(out, Cow::Borrowed(_)));
        assert_eq!(out, svg);
    }

    #[test]
    fn push_postprocessor_appends_in_place() {
        let mut pipeline = SvgPipeline::parity();
        pipeline.push_postprocessor(AppendPass("pushed"));

        let out = pipeline.process_to_string("<svg/>").unwrap();

        assert!(out.contains("<!--0:pushed:Parity:"));
    }

    #[test]
    fn shared_postprocessor_can_be_registered_repeatedly() {
        let shared: Arc<dyn SvgPostprocessor> = Arc::new(AppendPass("shared"));
        let pipeline = SvgPipeline::parity()
            .with_shared_postprocessor(Arc::clone(&shared))
            .with_shared_postprocessor(shared);

        let out = pipeline.process_to_string("<svg/>").unwrap();

        let first = out.find("<!--0:shared:Parity:").unwrap();
        let second = out.find("<!--1:shared:Parity:").unwrap();
        assert!(first < second);
    }

    #[test]
    fn debug_lists_postprocessors_by_name() {
        let pipeline = SvgPipeline::readable()
            .with_postprocessor(AppendPass("first"))
            .with_postprocessor(AppendPass("second"));

        let rendered = format!("{pipeline:?}");

        assert!(rendered.contains("SvgPipeline"));
        assert!(rendered.contains(r#"["first", "second"]"#));
    }

    /// Test pass that always fails, used to check error attribution.
    struct ErrorPass;

    impl SvgPostprocessor for ErrorPass {
        fn name(&self) -> &'static str {
            "error-pass"
        }

        fn process<'a>(
            &self,
            _svg: Cow<'a, str>,
            _ctx: &SvgPostprocessContext<'_>,
        ) -> Result<Cow<'a, str>> {
            Err(Error::InvalidModel {
                message: "boom".to_string(),
            })
        }
    }

    #[test]
    fn custom_postprocessor_errors_surface_with_pass_name() {
        let err = SvgPipeline::parity()
            .with_postprocessor(ErrorPass)
            .process_to_string("<svg/>")
            .unwrap_err();

        let message = err.to_string();
        assert!(message.contains("error-pass"));
        assert!(message.contains("boom"));
    }
}