Skip to main content

merman_render/
math.rs

//! Optional math rendering hooks.
//!
//! Upstream Mermaid renders `$$...$$` fragments via KaTeX and measures the resulting HTML in a
//! browser DOM. merman is headless and pure-Rust by default, so math rendering is modeled as an
//! optional, pluggable backend.
//!
//! The default implementation is a no-op. For parity work, an optional Node.js-backed KaTeX
//! renderer is also provided.
9
10use crate::text::{TextMetrics, TextStyle, WrapMode};
11use merman_core::MermaidConfig;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::fmt::Write as _;
15use std::io::Write as _;
16use std::path::{Path, PathBuf};
17use std::process::{Command, Stdio};
18use std::sync::Mutex;
19
/// Optional math renderer used to transform label HTML and (optionally) provide measurements.
///
/// Implementations should be:
/// - deterministic (stable output across runs),
/// - side-effect free (no global mutations),
/// - non-panicking (return `None` to decline handling).
///
/// Both methods use `None` to mean "not handled"; callers are expected to fall back to their
/// regular (non-math) text path in that case. Implementors must also implement
/// [`std::fmt::Debug`] (supertrait).
pub trait MathRenderer: std::fmt::Debug {
    /// Attempts to render math fragments within an HTML label string.
    ///
    /// If the renderer declines to handle the input, it should return `None`.
    ///
    /// The returned string is treated as raw HTML and will still be sanitized by merman before
    /// emitting into an SVG `<foreignObject>`.
    fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String>;

    /// Optionally measures the rendered HTML label in pixels.
    ///
    /// This is intended to mirror upstream Mermaid's DOM measurement behavior for math labels.
    /// The default implementation returns `None`.
    ///
    /// `style` carries the label's font settings. `max_width_px` and `wrap_mode` describe how
    /// the caller lays the label out (NOTE(review): presumably the same wrapping width/mode the
    /// plain-text measurement path uses — confirm against callers).
    fn measure_html_label(
        &self,
        _text: &str,
        _config: &MermaidConfig,
        _style: &TextStyle,
        _max_width_px: Option<f64>,
        _wrap_mode: WrapMode,
    ) -> Option<TextMetrics> {
        None
    }
}
50
51/// Default math renderer: does nothing.
52#[derive(Debug, Default, Clone, Copy)]
53pub struct NoopMathRenderer;
54
55impl MathRenderer for NoopMathRenderer {
56    fn render_html_label(&self, _text: &str, _config: &MermaidConfig) -> Option<String> {
57        None
58    }
59}
60
/// Cache key for rendered label HTML: the normalized label text plus the two math flags
/// sent along with it.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RenderCacheKey {
    text: String,
    legacy_mathml: bool,
    force_legacy_mathml: bool,
}

/// Cache key for measurement probes.
///
/// Floating-point inputs are stored as raw bit patterns (`f64::to_bits`), which lets the key
/// derive `Eq` and `Hash`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct ProbeCacheKey {
    render: RenderCacheKey,
    font_family: Option<String>,
    font_size_bits: u64,
    font_weight: Option<String>,
    max_width_bits: u64,
}

/// Successful probe result: the HTML returned by the probe plus its size and line count.
#[derive(Clone, Debug)]
struct ProbeCacheValue {
    html: String,
    width: f64,
    height: f64,
    line_count: usize,
}
84
85#[derive(Debug, Serialize)]
86struct NodeRenderRequest {
87    text: String,
88    config: NodeMathConfig,
89}
90
91#[derive(Debug, Serialize)]
92struct NodeProbeRequest {
93    text: String,
94    config: NodeMathConfig,
95    #[serde(rename = "styleCss")]
96    style_css: String,
97    #[serde(rename = "maxWidthPx")]
98    max_width_px: f64,
99}
100
101#[derive(Debug, Serialize)]
102struct NodeMathConfig {
103    #[serde(rename = "legacyMathML")]
104    legacy_mathml: bool,
105    #[serde(rename = "forceLegacyMathML")]
106    force_legacy_mathml: bool,
107}
108
109#[derive(Debug, Deserialize)]
110struct NodeRenderResponse {
111    html: String,
112}
113
114#[derive(Debug, Deserialize)]
115struct NodeProbeResponse {
116    html: String,
117    width: f64,
118    height: f64,
119}
120
/// Optional KaTeX backend that shells out to a local Node.js toolchain.
///
/// This backend is intended for parity work where a real browser DOM is available. It mirrors
/// Mermaid's flowchart HTML-label KaTeX path closely by:
/// - rendering KaTeX through the local `katex` npm package, and
/// - measuring the wrapped `<foreignObject>` HTML through local `puppeteer`.
///
/// The backend is completely opt-in; if the configured Node.js environment is unavailable or the
/// probe fails, it simply returns `None` and lets callers fall back to the default text path.
///
/// Results are memoized per instance, including failures (stored as `None`), so a failed
/// Node.js invocation for a given key is not retried by the same renderer.
#[derive(Debug)]
pub struct NodeKatexMathRenderer {
    // Working directory for the Node.js process; requests are declined unless it contains a
    // `package.json`.
    node_cwd: PathBuf,
    // Node.js executable to spawn (`node` unless overridden).
    node_command: PathBuf,
    // Rendered HTML keyed by normalized text + math flags; `None` records a failed render.
    render_cache: Mutex<HashMap<RenderCacheKey, Option<String>>>,
    // Probe results keyed by render key + font settings + max width; `None` records a failure.
    probe_cache: Mutex<HashMap<ProbeCacheKey, Option<ProbeCacheValue>>>,
}
137
138impl NodeKatexMathRenderer {
139    pub fn new(node_cwd: impl Into<PathBuf>) -> Self {
140        Self {
141            node_cwd: node_cwd.into(),
142            node_command: PathBuf::from("node"),
143            render_cache: Mutex::new(HashMap::new()),
144            probe_cache: Mutex::new(HashMap::new()),
145        }
146    }
147
148    pub fn with_node_command(mut self, node_command: impl Into<PathBuf>) -> Self {
149        self.node_command = node_command.into();
150        self
151    }
152
153    fn script_path() -> PathBuf {
154        Path::new(env!("CARGO_MANIFEST_DIR"))
155            .join("assets")
156            .join("katex_flowchart_probe.cjs")
157    }
158
159    fn normalized_text(text: &str) -> String {
160        text.replace("\\\\", "\\")
161    }
162
163    fn math_config(config: &MermaidConfig) -> NodeMathConfig {
164        let config_value = config.as_value();
165        let legacy_mathml = config_value
166            .get("legacyMathML")
167            .and_then(serde_json::Value::as_bool)
168            .unwrap_or(false);
169        let force_legacy_mathml = config_value
170            .get("forceLegacyMathML")
171            .and_then(serde_json::Value::as_bool)
172            .unwrap_or(false);
173        NodeMathConfig {
174            legacy_mathml,
175            force_legacy_mathml,
176        }
177    }
178
179    fn render_key(text: &str, config: &MermaidConfig) -> RenderCacheKey {
180        let config = Self::math_config(config);
181        RenderCacheKey {
182            text: Self::normalized_text(text),
183            legacy_mathml: config.legacy_mathml,
184            force_legacy_mathml: config.force_legacy_mathml,
185        }
186    }
187
188    fn style_css(style: &TextStyle) -> String {
189        let mut out = String::new();
190        let font_family = style
191            .font_family
192            .as_deref()
193            .unwrap_or("\"trebuchet ms\",verdana,arial,sans-serif");
194        let _ = write!(&mut out, "font-size: {}px;", style.font_size);
195        let _ = write!(&mut out, "font-family: {};", font_family);
196        if let Some(font_weight) = style.font_weight.as_deref() {
197            if !font_weight.trim().is_empty() {
198                let _ = write!(&mut out, "font-weight: {};", font_weight.trim());
199            }
200        }
201        out
202    }
203
204    fn run_node_request<T, R>(&self, mode: &str, payload: &T) -> Option<R>
205    where
206        T: Serialize,
207        R: for<'de> Deserialize<'de>,
208    {
209        if !self.node_cwd.join("package.json").is_file() {
210            return None;
211        }
212
213        let mut child = Command::new(&self.node_command)
214            .arg(Self::script_path())
215            .arg(mode)
216            .current_dir(&self.node_cwd)
217            .stdin(Stdio::piped())
218            .stdout(Stdio::piped())
219            .stderr(Stdio::piped())
220            .spawn()
221            .ok()?;
222
223        if let Some(mut stdin) = child.stdin.take() {
224            if serde_json::to_writer(&mut stdin, payload).is_err() {
225                return None;
226            }
227            let _ = stdin.flush();
228        }
229
230        let output = child.wait_with_output().ok()?;
231        if !output.status.success() {
232            return None;
233        }
234
235        serde_json::from_slice(&output.stdout).ok()
236    }
237
238    fn render_cached(&self, text: &str, config: &MermaidConfig) -> Option<String> {
239        let key = Self::render_key(text, config);
240        if let Some(cached) = self
241            .render_cache
242            .lock()
243            .ok()
244            .and_then(|cache| cache.get(&key).cloned())
245        {
246            return cached;
247        }
248
249        let response: Option<NodeRenderResponse> = self.run_node_request(
250            "render",
251            &NodeRenderRequest {
252                text: key.text.clone(),
253                config: NodeMathConfig {
254                    legacy_mathml: key.legacy_mathml,
255                    force_legacy_mathml: key.force_legacy_mathml,
256                },
257            },
258        );
259        let html = response.map(|value| value.html);
260
261        if let Ok(mut cache) = self.render_cache.lock() {
262            cache.insert(key, html.clone());
263        }
264
265        html
266    }
267
268    fn probe_cached(
269        &self,
270        text: &str,
271        config: &MermaidConfig,
272        style: &TextStyle,
273        max_width_px: Option<f64>,
274        _wrap_mode: WrapMode,
275    ) -> Option<ProbeCacheValue> {
276        let render = Self::render_key(text, config);
277        let max_width = max_width_px.unwrap_or(200.0).max(1.0);
278        let key = ProbeCacheKey {
279            render: render.clone(),
280            font_family: style.font_family.clone(),
281            font_size_bits: style.font_size.to_bits(),
282            font_weight: style.font_weight.clone(),
283            max_width_bits: max_width.to_bits(),
284        };
285        if let Some(cached) = self
286            .probe_cache
287            .lock()
288            .ok()
289            .and_then(|cache| cache.get(&key).cloned())
290        {
291            return cached;
292        }
293
294        let style_css = Self::style_css(style);
295        let response: Option<NodeProbeResponse> = self.run_node_request(
296            "probe",
297            &NodeProbeRequest {
298                text: render.text.clone(),
299                config: NodeMathConfig {
300                    legacy_mathml: render.legacy_mathml,
301                    force_legacy_mathml: render.force_legacy_mathml,
302                },
303                style_css,
304                max_width_px: max_width,
305            },
306        );
307        let probed = response.and_then(|value| {
308            if !value.width.is_finite() || !value.height.is_finite() {
309                return None;
310            }
311            let line_count = value.html.match_indices("<div").count().max(1);
312            Some(ProbeCacheValue {
313                html: value.html,
314                width: value.width.max(0.0),
315                height: value.height.max(0.0),
316                line_count,
317            })
318        });
319
320        if let Some(probed_value) = probed.clone() {
321            if let Ok(mut render_cache) = self.render_cache.lock() {
322                render_cache
323                    .entry(render)
324                    .or_insert_with(|| Some(probed_value.html.clone()));
325            }
326        }
327        if let Ok(mut cache) = self.probe_cache.lock() {
328            cache.insert(key, probed.clone());
329        }
330
331        probed
332    }
333}
334
335impl MathRenderer for NodeKatexMathRenderer {
336    fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String> {
337        if !text.contains("$$") {
338            return None;
339        }
340        self.render_cached(text, config)
341    }
342
343    fn measure_html_label(
344        &self,
345        text: &str,
346        config: &MermaidConfig,
347        style: &TextStyle,
348        max_width_px: Option<f64>,
349        wrap_mode: WrapMode,
350    ) -> Option<TextMetrics> {
351        if wrap_mode != WrapMode::HtmlLike || !text.contains("$$") {
352            return None;
353        }
354        let probed = self.probe_cached(text, config, style, max_width_px, wrap_mode)?;
355        Some(TextMetrics {
356            width: crate::text::round_to_1_64_px(probed.width),
357            height: crate::text::round_to_1_64_px(probed.height),
358            line_count: probed.line_count,
359        })
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Locates the repository's `tools/mermaid-cli` Node.js workspace.
    ///
    /// Returns `None` when its `package.json` or `node_modules` is missing. The Node-backed
    /// tests below then return early and pass without checking anything.
    fn mermaid_cli_dir() -> Option<PathBuf> {
        let dir = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("tools")
            .join("mermaid-cli");
        if dir.join("package.json").is_file() && dir.join("node_modules").is_dir() {
            Some(dir)
        } else {
            None
        }
    }

    /// Asserts `actual` is within 1e-9 of `expected`, reporting `what` on failure.
    fn assert_close(actual: f64, expected: f64, what: &str) {
        assert!((actual - expected).abs() < 1e-9, "{what} = {actual}");
    }

    #[test]
    fn node_katex_math_renderer_smoke() {
        let node_cwd = match mermaid_cli_dir() {
            Some(dir) => dir,
            None => return,
        };

        let renderer = NodeKatexMathRenderer::new(node_cwd);
        let config = MermaidConfig::default();
        let style = TextStyle::default();
        let label = "$$x^2$$";

        let html = renderer
            .render_html_label(label, &config)
            .expect("node KaTeX renderer should produce HTML");
        assert!(html.contains("katex"), "unexpected HTML: {html}");

        let metrics = renderer
            .measure_html_label(label, &config, &style, Some(200.0), WrapMode::HtmlLike)
            .expect("node KaTeX renderer should produce metrics");
        assert!(metrics.width.is_finite() && metrics.width > 0.0);
        assert!(metrics.height.is_finite() && metrics.height > 0.0);
    }

    #[test]
    fn node_katex_math_renderer_matches_flowchart_browser_shell_metrics() {
        const NODE_LABEL: &str =
            "$$f(\\relax{x}) = \\int_{-\\infty}^\\infty \\hat{f}(\\xi)\\,e^{2 \\pi i \\xi x}\\,d\\xi$$";
        const EDGE_LABEL: &str =
            "$$\\Bigg(\\bigg(\\Big(\\big((\\frac{-b\\pm\\sqrt{b^2-4ac}}{2a})\\big)\\Big)\\bigg)\\Bigg)$$";

        let node_cwd = match mermaid_cli_dir() {
            Some(dir) => dir,
            None => return,
        };

        let renderer = NodeKatexMathRenderer::new(node_cwd);
        let config = MermaidConfig::default();
        let style = TextStyle::default();
        let measure = |label: &str| {
            renderer.measure_html_label(label, &config, &style, Some(200.0), WrapMode::HtmlLike)
        };

        let node_metrics = measure(NODE_LABEL).expect("node label metrics");
        assert_close(node_metrics.width, 195.140625, "node width");
        assert_close(node_metrics.height, 27.53125, "node height");

        let edge_metrics = measure(EDGE_LABEL).expect("edge label metrics");
        assert_close(edge_metrics.width, 184.78125, "edge width");
        assert_close(edge_metrics.height, 41.53125, "edge height");
    }
}