Skip to main content

merman_render/
math.rs

1//! Optional math rendering hooks.
2//!
3//! Upstream Mermaid renders `$$...$$` fragments via KaTeX and measures the resulting HTML in a
4//! browser DOM. merman is headless and pure-Rust by default, so math rendering is modeled as an
5//! optional, pluggable backend.
6//!
7//! The default implementation is a no-op. For parity work, an optional Node.js-backed KaTeX
8//! renderer is also provided.
9
10use crate::text::{TextMetrics, TextStyle, WrapMode};
11use merman_core::MermaidConfig;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::fmt::Write as _;
15use std::io::Write as _;
16use std::path::{Path, PathBuf};
17use std::process::{Command, Stdio};
18use std::sync::Mutex;
19
20/// Optional math renderer used to transform label HTML and (optionally) provide measurements.
21///
22/// Implementations should be:
23/// - deterministic (stable output across runs),
24/// - side-effect free (no global mutations),
25/// - non-panicking (return `None` to decline handling).
26pub trait MathRenderer: std::fmt::Debug {
27    /// Attempts to render math fragments within an HTML label string.
28    ///
29    /// If the renderer declines to handle the input, it should return `None`.
30    ///
31    /// The returned string is treated as raw HTML and will still be sanitized by merman before
32    /// emitting into an SVG `<foreignObject>`.
33    fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String>;
34
35    /// Optionally measures the rendered HTML label in pixels.
36    ///
37    /// This is intended to mirror upstream Mermaid's DOM measurement behavior for math labels.
38    /// The default implementation returns `None`.
39    fn measure_html_label(
40        &self,
41        _text: &str,
42        _config: &MermaidConfig,
43        _style: &TextStyle,
44        _max_width_px: Option<f64>,
45        _wrap_mode: WrapMode,
46    ) -> Option<TextMetrics> {
47        None
48    }
49
50    /// Optionally measures a Sequence `drawKatex(...)` label in pixels.
51    ///
52    /// Mermaid Sequence does not wrap KaTeX labels in the flowchart HTML-label shell; it appends
53    /// a bare `<foreignObject><div style="width: fit-content;">...</div></foreignObject>`.
54    /// This hook lets Sequence callers avoid inheriting flowchart-specific table-cell metrics.
55    fn measure_sequence_html_label(
56        &self,
57        _text: &str,
58        _config: &MermaidConfig,
59    ) -> Option<TextMetrics> {
60        None
61    }
62}
63
/// The default math renderer.
///
/// It is a stateless unit type that performs no math rendering at all.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopMathRenderer;
67
68impl MathRenderer for NoopMathRenderer {
69    fn render_html_label(&self, _text: &str, _config: &MermaidConfig) -> Option<String> {
70        None
71    }
72}
73
/// Cache key identifying one KaTeX render: the label text plus the two MathML flags that
/// are forwarded to the Node.js script.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RenderCacheKey {
    /// Label text as sent to the script.
    text: String,
    /// Value of the `legacyMathML` config flag.
    legacy_mathml: bool,
    /// Value of the `forceLegacyMathML` config flag.
    force_legacy_mathml: bool,
}
80
81#[derive(Debug, Clone, PartialEq, Eq, Hash)]
82struct ProbeCacheKey {
83    render: RenderCacheKey,
84    font_family: Option<String>,
85    font_size_bits: u64,
86    font_weight: Option<String>,
87    max_width_bits: u64,
88}
89
/// Result of a successful probe: the HTML the script returned together with its measured
/// size.
#[derive(Debug, Clone)]
struct ProbeCacheValue {
    /// HTML returned by the probe script.
    html: String,
    /// Measured width (finite and non-negative once stored).
    width: f64,
    /// Measured height (finite and non-negative once stored).
    height: f64,
    /// Line count derived from the returned HTML; always at least 1.
    line_count: usize,
}
97
98#[derive(Debug, Serialize)]
99struct NodeRenderRequest {
100    text: String,
101    config: NodeMathConfig,
102}
103
104#[derive(Debug, Serialize)]
105struct NodeProbeRequest {
106    text: String,
107    config: NodeMathConfig,
108    #[serde(rename = "styleCss")]
109    style_css: String,
110    #[serde(rename = "maxWidthPx")]
111    max_width_px: f64,
112}
113
114#[derive(Debug, Serialize)]
115struct NodeMathConfig {
116    #[serde(rename = "legacyMathML")]
117    legacy_mathml: bool,
118    #[serde(rename = "forceLegacyMathML")]
119    force_legacy_mathml: bool,
120}
121
122#[derive(Debug, Deserialize)]
123struct NodeRenderResponse {
124    html: String,
125}
126
127#[derive(Debug, Deserialize)]
128struct NodeProbeResponse {
129    html: String,
130    width: f64,
131    height: f64,
132}
133
134/// Optional KaTeX backend that shells out to a local Node.js toolchain.
135///
136/// This backend is intended for parity work where a real browser DOM is available. It mirrors
137/// Mermaid's flowchart HTML-label KaTeX path closely by:
138/// - rendering KaTeX through the local `katex` npm package, and
139/// - measuring the wrapped `<foreignObject>` HTML through local `puppeteer`.
140///
141/// The backend is completely opt-in; if the configured Node.js environment is unavailable or the
142/// probe fails, it simply returns `None` and lets callers fall back to the default text path.
143#[derive(Debug)]
144pub struct NodeKatexMathRenderer {
145    node_cwd: PathBuf,
146    node_command: PathBuf,
147    render_cache: Mutex<HashMap<RenderCacheKey, Option<String>>>,
148    probe_cache: Mutex<HashMap<ProbeCacheKey, Option<ProbeCacheValue>>>,
149    sequence_probe_cache: Mutex<HashMap<RenderCacheKey, Option<ProbeCacheValue>>>,
150}
151
152impl NodeKatexMathRenderer {
153    pub fn new(node_cwd: impl Into<PathBuf>) -> Self {
154        Self {
155            node_cwd: node_cwd.into(),
156            node_command: PathBuf::from("node"),
157            render_cache: Mutex::new(HashMap::new()),
158            probe_cache: Mutex::new(HashMap::new()),
159            sequence_probe_cache: Mutex::new(HashMap::new()),
160        }
161    }
162
163    pub fn with_node_command(mut self, node_command: impl Into<PathBuf>) -> Self {
164        self.node_command = node_command.into();
165        self
166    }
167
168    fn script_path() -> PathBuf {
169        Path::new(env!("CARGO_MANIFEST_DIR"))
170            .join("assets")
171            .join("katex_flowchart_probe.cjs")
172    }
173
174    fn normalized_text(text: &str) -> String {
175        text.replace("\\\\", "\\")
176    }
177
178    fn math_config(config: &MermaidConfig) -> NodeMathConfig {
179        let config_value = config.as_value();
180        let legacy_mathml = config_value
181            .get("legacyMathML")
182            .and_then(serde_json::Value::as_bool)
183            .unwrap_or(false);
184        let force_legacy_mathml = config_value
185            .get("forceLegacyMathML")
186            .and_then(serde_json::Value::as_bool)
187            .unwrap_or(false);
188        NodeMathConfig {
189            legacy_mathml,
190            force_legacy_mathml,
191        }
192    }
193
194    fn render_key(text: &str, config: &MermaidConfig) -> RenderCacheKey {
195        let config = Self::math_config(config);
196        RenderCacheKey {
197            text: Self::normalized_text(text),
198            legacy_mathml: config.legacy_mathml,
199            force_legacy_mathml: config.force_legacy_mathml,
200        }
201    }
202
203    fn style_css(style: &TextStyle) -> String {
204        let mut out = String::new();
205        let font_family = style
206            .font_family
207            .as_deref()
208            .unwrap_or("\"trebuchet ms\",verdana,arial,sans-serif");
209        let _ = write!(&mut out, "font-size: {}px;", style.font_size);
210        let _ = write!(&mut out, "font-family: {};", font_family);
211        if let Some(font_weight) = style.font_weight.as_deref() {
212            if !font_weight.trim().is_empty() {
213                let _ = write!(&mut out, "font-weight: {};", font_weight.trim());
214            }
215        }
216        out
217    }
218
219    fn run_node_request<T, R>(&self, mode: &str, payload: &T) -> Option<R>
220    where
221        T: Serialize,
222        R: for<'de> Deserialize<'de>,
223    {
224        if !self.node_cwd.join("package.json").is_file() {
225            return None;
226        }
227
228        let mut child = Command::new(&self.node_command)
229            .arg(Self::script_path())
230            .arg(mode)
231            .current_dir(&self.node_cwd)
232            .stdin(Stdio::piped())
233            .stdout(Stdio::piped())
234            .stderr(Stdio::piped())
235            .spawn()
236            .ok()?;
237
238        if let Some(mut stdin) = child.stdin.take() {
239            if serde_json::to_writer(&mut stdin, payload).is_err() {
240                return None;
241            }
242            let _ = stdin.flush();
243        }
244
245        let output = child.wait_with_output().ok()?;
246        if !output.status.success() {
247            return None;
248        }
249
250        serde_json::from_slice(&output.stdout).ok()
251    }
252
253    fn render_cached(&self, text: &str, config: &MermaidConfig) -> Option<String> {
254        let key = Self::render_key(text, config);
255        if let Some(cached) = self
256            .render_cache
257            .lock()
258            .ok()
259            .and_then(|cache| cache.get(&key).cloned())
260        {
261            return cached;
262        }
263
264        let response: Option<NodeRenderResponse> = self.run_node_request(
265            "render",
266            &NodeRenderRequest {
267                text: key.text.clone(),
268                config: NodeMathConfig {
269                    legacy_mathml: key.legacy_mathml,
270                    force_legacy_mathml: key.force_legacy_mathml,
271                },
272            },
273        );
274        let html = response.map(|value| value.html);
275
276        if let Ok(mut cache) = self.render_cache.lock() {
277            cache.insert(key, html.clone());
278        }
279
280        html
281    }
282
283    fn probe_cached(
284        &self,
285        text: &str,
286        config: &MermaidConfig,
287        style: &TextStyle,
288        max_width_px: Option<f64>,
289        _wrap_mode: WrapMode,
290    ) -> Option<ProbeCacheValue> {
291        let render = Self::render_key(text, config);
292        let max_width = max_width_px.unwrap_or(200.0).max(1.0);
293        let key = ProbeCacheKey {
294            render: render.clone(),
295            font_family: style.font_family.clone(),
296            font_size_bits: style.font_size.to_bits(),
297            font_weight: style.font_weight.clone(),
298            max_width_bits: max_width.to_bits(),
299        };
300        if let Some(cached) = self
301            .probe_cache
302            .lock()
303            .ok()
304            .and_then(|cache| cache.get(&key).cloned())
305        {
306            return cached;
307        }
308
309        let style_css = Self::style_css(style);
310        let response: Option<NodeProbeResponse> = self.run_node_request(
311            "probe",
312            &NodeProbeRequest {
313                text: render.text.clone(),
314                config: NodeMathConfig {
315                    legacy_mathml: render.legacy_mathml,
316                    force_legacy_mathml: render.force_legacy_mathml,
317                },
318                style_css,
319                max_width_px: max_width,
320            },
321        );
322        let probed = response.and_then(|value| {
323            if !value.width.is_finite() || !value.height.is_finite() {
324                return None;
325            }
326            let line_count = value.html.match_indices("<div").count().max(1);
327            Some(ProbeCacheValue {
328                html: value.html,
329                width: value.width.max(0.0),
330                height: value.height.max(0.0),
331                line_count,
332            })
333        });
334
335        if let Some(probed_value) = probed.clone() {
336            if let Ok(mut render_cache) = self.render_cache.lock() {
337                render_cache
338                    .entry(render)
339                    .or_insert_with(|| Some(probed_value.html.clone()));
340            }
341        }
342        if let Ok(mut cache) = self.probe_cache.lock() {
343            cache.insert(key, probed.clone());
344        }
345
346        probed
347    }
348
349    fn sequence_probe_cached(&self, text: &str, config: &MermaidConfig) -> Option<ProbeCacheValue> {
350        let key = Self::render_key(text, config);
351        if let Some(cached) = self
352            .sequence_probe_cache
353            .lock()
354            .ok()
355            .and_then(|cache| cache.get(&key).cloned())
356        {
357            return cached;
358        }
359
360        let response: Option<NodeProbeResponse> = self.run_node_request(
361            "probe-sequence",
362            &NodeRenderRequest {
363                text: key.text.clone(),
364                config: NodeMathConfig {
365                    legacy_mathml: key.legacy_mathml,
366                    force_legacy_mathml: key.force_legacy_mathml,
367                },
368            },
369        );
370        let probed = response.and_then(|value| {
371            if !value.width.is_finite() || !value.height.is_finite() {
372                return None;
373            }
374            let line_count = value.html.match_indices("<div").count().max(1);
375            Some(ProbeCacheValue {
376                html: value.html,
377                width: value.width.max(0.0),
378                height: value.height.max(0.0),
379                line_count,
380            })
381        });
382
383        if let Some(probed_value) = probed.clone() {
384            if let Ok(mut render_cache) = self.render_cache.lock() {
385                render_cache
386                    .entry(key.clone())
387                    .or_insert_with(|| Some(probed_value.html.clone()));
388            }
389        }
390        if let Ok(mut cache) = self.sequence_probe_cache.lock() {
391            cache.insert(key, probed.clone());
392        }
393
394        probed
395    }
396}
397
398impl MathRenderer for NodeKatexMathRenderer {
399    fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String> {
400        if !text.contains("$$") {
401            return None;
402        }
403        self.render_cached(text, config)
404    }
405
406    fn measure_html_label(
407        &self,
408        text: &str,
409        config: &MermaidConfig,
410        style: &TextStyle,
411        max_width_px: Option<f64>,
412        wrap_mode: WrapMode,
413    ) -> Option<TextMetrics> {
414        if wrap_mode != WrapMode::HtmlLike || !text.contains("$$") {
415            return None;
416        }
417        let probed = self.probe_cached(text, config, style, max_width_px, wrap_mode)?;
418        Some(TextMetrics {
419            width: crate::text::round_to_1_64_px(probed.width),
420            height: crate::text::round_to_1_64_px(probed.height),
421            line_count: probed.line_count,
422        })
423    }
424
425    fn measure_sequence_html_label(
426        &self,
427        text: &str,
428        config: &MermaidConfig,
429    ) -> Option<TextMetrics> {
430        if !text.contains("$$") {
431            return None;
432        }
433        let probed = self.sequence_probe_cached(text, config)?;
434        Some(TextMetrics {
435            width: crate::text::round_to_1_64_px(probed.width),
436            height: crate::text::round_to_1_64_px(probed.height),
437            line_count: probed.line_count,
438        })
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `tools/mermaid-cli` directory two levels above this crate, but only when
    /// it holds both a `package.json` and an installed `node_modules`.
    ///
    /// Tests return early (counting as passed) when this yields `None`.
    fn mermaid_cli_dir() -> Option<PathBuf> {
        let dir = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("tools")
            .join("mermaid-cli");
        let usable = dir.join("package.json").is_file() && dir.join("node_modules").is_dir();
        usable.then_some(dir)
    }

    /// Renders and measures a tiny formula; every backend failure is treated as a skip.
    #[test]
    fn node_katex_math_renderer_smoke() {
        let Some(node_cwd) = mermaid_cli_dir() else {
            return;
        };

        let renderer = NodeKatexMathRenderer::new(node_cwd);
        let config = MermaidConfig::default();
        let style = TextStyle::default();

        let Some(html) = renderer.render_html_label("$$x^2$$", &config) else {
            return;
        };
        assert!(html.contains("katex"), "unexpected HTML: {html}");

        let measured = renderer.measure_html_label(
            "$$x^2$$",
            &config,
            &style,
            Some(200.0),
            WrapMode::HtmlLike,
        );
        let Some(metrics) = measured else {
            return;
        };
        assert!(metrics.width.is_finite() && metrics.width > 0.0);
        assert!(metrics.height.is_finite() && metrics.height > 0.0);
    }

    /// Checks measured sizes for two larger formulas and that a successful probe leaves
    /// sanitized HTML available to `render_html_label`.
    #[test]
    fn node_katex_math_renderer_measures_sanitized_flowchart_browser_shell() {
        let Some(node_cwd) = mermaid_cli_dir() else {
            return;
        };

        let renderer = NodeKatexMathRenderer::new(node_cwd);
        let config = MermaidConfig::default();
        let style = TextStyle::default();

        let long_integral = "$$f(\\relax{x}) = \\int_{-\\infty}^\\infty \\hat{f}(\\xi)\\,e^{2 \\pi i \\xi x}\\,d\\xi$$";
        let node_probe = renderer.measure_html_label(
            long_integral,
            &config,
            &style,
            Some(200.0),
            WrapMode::HtmlLike,
        );
        let Some(node_metrics) = node_probe else {
            return;
        };
        assert!(
            (180.0..=260.0).contains(&node_metrics.width),
            "node width = {}",
            node_metrics.width
        );
        assert!(
            (30.0..=70.0).contains(&node_metrics.height),
            "node height = {}",
            node_metrics.height
        );
        let Some(html) = renderer.render_html_label(long_integral, &config) else {
            panic!("expected rendered math HTML after successful probe");
        };
        assert!(html.contains("<math"), "unexpected HTML: {html}");
        assert!(!html.contains("<semantics>"), "unsanitized HTML: {html}");

        let nested_delimiters = "$$\\Bigg(\\bigg(\\Big(\\big((\\frac{-b\\pm\\sqrt{b^2-4ac}}{2a})\\big)\\Big)\\bigg)\\Bigg)$$";
        let edge_probe = renderer.measure_html_label(
            nested_delimiters,
            &config,
            &style,
            Some(200.0),
            WrapMode::HtmlLike,
        );
        let Some(edge_metrics) = edge_probe else {
            return;
        };
        assert!(
            (180.0..=320.0).contains(&edge_metrics.width),
            "edge width = {}",
            edge_metrics.width
        );
        assert!(
            (40.0..=100.0).contains(&edge_metrics.height),
            "edge height = {}",
            edge_metrics.height
        );
    }
}
541}