1use crate::text::{TextMetrics, TextStyle, WrapMode};
11use merman_core::MermaidConfig;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::fmt::Write as _;
15use std::io::Write as _;
16use std::path::{Path, PathBuf};
17use std::process::{Command, Stdio};
18use std::sync::Mutex;
19
/// Pluggable renderer for math (`$$…$$`) inside HTML labels.
///
/// Every method returns `Option`; `None` means "no math rendering / measurement available".
/// The two measurement methods default to `None`, so a minimal implementation only has to
/// provide [`MathRenderer::render_html_label`]. Implementors must also implement `Debug`.
pub trait MathRenderer: std::fmt::Debug {
    /// Renders `text` as label HTML, or returns `None` when this renderer produces nothing
    /// for it (in this file, `NoopMathRenderer` always returns `None`, and
    /// `NodeKatexMathRenderer` does so for text without `$$` or when the Node call fails).
    ///
    /// NOTE(review): callers presumably fall back to their plain-label path on `None` —
    /// confirm at the call sites.
    fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String>;

    /// Measures `text` rendered as an HTML label with `style`, optionally constrained to
    /// `max_width_px` and laid out according to `wrap_mode`.
    ///
    /// The default implementation measures nothing and returns `None`.
    fn measure_html_label(
        &self,
        _text: &str,
        _config: &MermaidConfig,
        _style: &TextStyle,
        _max_width_px: Option<f64>,
        _wrap_mode: WrapMode,
    ) -> Option<TextMetrics> {
        None
    }

    /// Measures `text` as a sequence-diagram HTML label.
    ///
    /// The default implementation measures nothing and returns `None`.
    fn measure_sequence_html_label(
        &self,
        _text: &str,
        _config: &MermaidConfig,
    ) -> Option<TextMetrics> {
        None
    }
}
63
64#[derive(Debug, Default, Clone, Copy)]
66pub struct NoopMathRenderer;
67
68impl MathRenderer for NoopMathRenderer {
69 fn render_html_label(&self, _text: &str, _config: &MermaidConfig) -> Option<String> {
70 None
71 }
72}
73
/// Cache key for one KaTeX render: the label text plus the math options sent with it.
///
/// Built by `NodeKatexMathRenderer::render_key`; its fields are also what gets forwarded to
/// the script's `render` and `probe-sequence` modes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RenderCacheKey {
    // Label text after `normalized_text` (doubled backslashes collapsed).
    text: String,
    // Top-level `legacyMathML` config flag (`false` when absent).
    legacy_mathml: bool,
    // Top-level `forceLegacyMathML` config flag (`false` when absent).
    force_legacy_mathml: bool,
}
80
/// Cache key for a width-constrained `probe` measurement.
///
/// Extends [`RenderCacheKey`] with the style and width inputs that affect layout. Floats are
/// stored as their bit patterns (`f64::to_bits`) because `f64` implements neither `Eq` nor
/// `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ProbeCacheKey {
    // Text and math options of the label being measured.
    render: RenderCacheKey,
    // Raw `TextStyle::font_family`, before any default is substituted.
    font_family: Option<String>,
    // `TextStyle::font_size` as raw bits.
    font_size_bits: u64,
    // Raw `TextStyle::font_weight`.
    font_weight: Option<String>,
    // Effective max width (after defaulting and clamping) as raw bits.
    max_width_bits: u64,
}
89
/// Validated result of a probe call.
#[derive(Debug, Clone)]
struct ProbeCacheValue {
    // HTML returned by the script; also reused as the render result for the same label.
    html: String,
    // Measured width, finite and clamped to >= 0 (presumably CSS px — confirm in the script).
    width: f64,
    // Measured height, finite and clamped to >= 0.
    height: f64,
    // Number of `<div` occurrences in `html`, with a minimum of 1.
    line_count: usize,
}
97
/// JSON stdin payload for the script's `render` and `probe-sequence` modes.
#[derive(Debug, Serialize)]
struct NodeRenderRequest {
    // Normalized label text.
    text: String,
    // Math options, serialized with Mermaid's config key names.
    config: NodeMathConfig,
}
103
/// JSON stdin payload for the script's `probe` mode.
#[derive(Debug, Serialize)]
struct NodeProbeRequest {
    // Normalized label text.
    text: String,
    // Math options, serialized with Mermaid's config key names.
    config: NodeMathConfig,
    // Inline CSS (font size / family / weight) built by `style_css`.
    #[serde(rename = "styleCss")]
    style_css: String,
    // Width constraint for the probe (defaulted and clamped by the caller).
    #[serde(rename = "maxWidthPx")]
    max_width_px: f64,
}
113
/// Math-related options forwarded to the Node script.
///
/// Field names are renamed to the camelCase keys they are read from in the Mermaid config.
#[derive(Debug, Serialize)]
struct NodeMathConfig {
    #[serde(rename = "legacyMathML")]
    legacy_mathml: bool,
    #[serde(rename = "forceLegacyMathML")]
    force_legacy_mathml: bool,
}
121
/// JSON stdout response of the script's `render` mode.
///
/// Only `html` is read; serde ignores any other fields by default.
#[derive(Debug, Deserialize)]
struct NodeRenderResponse {
    html: String,
}
126
/// JSON stdout response of the script's `probe` and `probe-sequence` modes.
///
/// The values are unvalidated here; non-finite sizes are rejected before caching.
#[derive(Debug, Deserialize)]
struct NodeProbeResponse {
    html: String,
    width: f64,
    height: f64,
}
133
/// [`MathRenderer`] backed by a KaTeX probe script (`assets/katex_flowchart_probe.cjs`)
/// executed with Node.js.
///
/// Each cache miss spawns one short-lived Node process. Results are memoized per instance,
/// including failures (stored as `None`). A label is therefore sent to Node at most once per
/// cache and key, unless two threads miss the same key concurrently.
#[derive(Debug)]
pub struct NodeKatexMathRenderer {
    // Working directory for `node`; every request returns `None` unless it has `package.json`.
    node_cwd: PathBuf,
    // Executable used to run the script (`node` unless overridden).
    node_command: PathBuf,
    // `render` results; successful probes also seed this cache with their HTML.
    render_cache: Mutex<HashMap<RenderCacheKey, Option<String>>>,
    // `probe` (width-constrained) results.
    probe_cache: Mutex<HashMap<ProbeCacheKey, Option<ProbeCacheValue>>>,
    // `probe-sequence` results.
    sequence_probe_cache: Mutex<HashMap<RenderCacheKey, Option<ProbeCacheValue>>>,
}
151
152impl NodeKatexMathRenderer {
153 pub fn new(node_cwd: impl Into<PathBuf>) -> Self {
154 Self {
155 node_cwd: node_cwd.into(),
156 node_command: PathBuf::from("node"),
157 render_cache: Mutex::new(HashMap::new()),
158 probe_cache: Mutex::new(HashMap::new()),
159 sequence_probe_cache: Mutex::new(HashMap::new()),
160 }
161 }
162
163 pub fn with_node_command(mut self, node_command: impl Into<PathBuf>) -> Self {
164 self.node_command = node_command.into();
165 self
166 }
167
168 fn script_path() -> PathBuf {
169 Path::new(env!("CARGO_MANIFEST_DIR"))
170 .join("assets")
171 .join("katex_flowchart_probe.cjs")
172 }
173
174 fn normalized_text(text: &str) -> String {
175 text.replace("\\\\", "\\")
176 }
177
178 fn math_config(config: &MermaidConfig) -> NodeMathConfig {
179 let config_value = config.as_value();
180 let legacy_mathml = config_value
181 .get("legacyMathML")
182 .and_then(serde_json::Value::as_bool)
183 .unwrap_or(false);
184 let force_legacy_mathml = config_value
185 .get("forceLegacyMathML")
186 .and_then(serde_json::Value::as_bool)
187 .unwrap_or(false);
188 NodeMathConfig {
189 legacy_mathml,
190 force_legacy_mathml,
191 }
192 }
193
194 fn render_key(text: &str, config: &MermaidConfig) -> RenderCacheKey {
195 let config = Self::math_config(config);
196 RenderCacheKey {
197 text: Self::normalized_text(text),
198 legacy_mathml: config.legacy_mathml,
199 force_legacy_mathml: config.force_legacy_mathml,
200 }
201 }
202
203 fn style_css(style: &TextStyle) -> String {
204 let mut out = String::new();
205 let font_family = style
206 .font_family
207 .as_deref()
208 .unwrap_or("\"trebuchet ms\",verdana,arial,sans-serif");
209 let _ = write!(&mut out, "font-size: {}px;", style.font_size);
210 let _ = write!(&mut out, "font-family: {};", font_family);
211 if let Some(font_weight) = style.font_weight.as_deref() {
212 if !font_weight.trim().is_empty() {
213 let _ = write!(&mut out, "font-weight: {};", font_weight.trim());
214 }
215 }
216 out
217 }
218
219 fn run_node_request<T, R>(&self, mode: &str, payload: &T) -> Option<R>
220 where
221 T: Serialize,
222 R: for<'de> Deserialize<'de>,
223 {
224 if !self.node_cwd.join("package.json").is_file() {
225 return None;
226 }
227
228 let mut child = Command::new(&self.node_command)
229 .arg(Self::script_path())
230 .arg(mode)
231 .current_dir(&self.node_cwd)
232 .stdin(Stdio::piped())
233 .stdout(Stdio::piped())
234 .stderr(Stdio::piped())
235 .spawn()
236 .ok()?;
237
238 if let Some(mut stdin) = child.stdin.take() {
239 if serde_json::to_writer(&mut stdin, payload).is_err() {
240 return None;
241 }
242 let _ = stdin.flush();
243 }
244
245 let output = child.wait_with_output().ok()?;
246 if !output.status.success() {
247 return None;
248 }
249
250 serde_json::from_slice(&output.stdout).ok()
251 }
252
253 fn render_cached(&self, text: &str, config: &MermaidConfig) -> Option<String> {
254 let key = Self::render_key(text, config);
255 if let Some(cached) = self
256 .render_cache
257 .lock()
258 .ok()
259 .and_then(|cache| cache.get(&key).cloned())
260 {
261 return cached;
262 }
263
264 let response: Option<NodeRenderResponse> = self.run_node_request(
265 "render",
266 &NodeRenderRequest {
267 text: key.text.clone(),
268 config: NodeMathConfig {
269 legacy_mathml: key.legacy_mathml,
270 force_legacy_mathml: key.force_legacy_mathml,
271 },
272 },
273 );
274 let html = response.map(|value| value.html);
275
276 if let Ok(mut cache) = self.render_cache.lock() {
277 cache.insert(key, html.clone());
278 }
279
280 html
281 }
282
283 fn probe_cached(
284 &self,
285 text: &str,
286 config: &MermaidConfig,
287 style: &TextStyle,
288 max_width_px: Option<f64>,
289 _wrap_mode: WrapMode,
290 ) -> Option<ProbeCacheValue> {
291 let render = Self::render_key(text, config);
292 let max_width = max_width_px.unwrap_or(200.0).max(1.0);
293 let key = ProbeCacheKey {
294 render: render.clone(),
295 font_family: style.font_family.clone(),
296 font_size_bits: style.font_size.to_bits(),
297 font_weight: style.font_weight.clone(),
298 max_width_bits: max_width.to_bits(),
299 };
300 if let Some(cached) = self
301 .probe_cache
302 .lock()
303 .ok()
304 .and_then(|cache| cache.get(&key).cloned())
305 {
306 return cached;
307 }
308
309 let style_css = Self::style_css(style);
310 let response: Option<NodeProbeResponse> = self.run_node_request(
311 "probe",
312 &NodeProbeRequest {
313 text: render.text.clone(),
314 config: NodeMathConfig {
315 legacy_mathml: render.legacy_mathml,
316 force_legacy_mathml: render.force_legacy_mathml,
317 },
318 style_css,
319 max_width_px: max_width,
320 },
321 );
322 let probed = response.and_then(|value| {
323 if !value.width.is_finite() || !value.height.is_finite() {
324 return None;
325 }
326 let line_count = value.html.match_indices("<div").count().max(1);
327 Some(ProbeCacheValue {
328 html: value.html,
329 width: value.width.max(0.0),
330 height: value.height.max(0.0),
331 line_count,
332 })
333 });
334
335 if let Some(probed_value) = probed.clone() {
336 if let Ok(mut render_cache) = self.render_cache.lock() {
337 render_cache
338 .entry(render)
339 .or_insert_with(|| Some(probed_value.html.clone()));
340 }
341 }
342 if let Ok(mut cache) = self.probe_cache.lock() {
343 cache.insert(key, probed.clone());
344 }
345
346 probed
347 }
348
349 fn sequence_probe_cached(&self, text: &str, config: &MermaidConfig) -> Option<ProbeCacheValue> {
350 let key = Self::render_key(text, config);
351 if let Some(cached) = self
352 .sequence_probe_cache
353 .lock()
354 .ok()
355 .and_then(|cache| cache.get(&key).cloned())
356 {
357 return cached;
358 }
359
360 let response: Option<NodeProbeResponse> = self.run_node_request(
361 "probe-sequence",
362 &NodeRenderRequest {
363 text: key.text.clone(),
364 config: NodeMathConfig {
365 legacy_mathml: key.legacy_mathml,
366 force_legacy_mathml: key.force_legacy_mathml,
367 },
368 },
369 );
370 let probed = response.and_then(|value| {
371 if !value.width.is_finite() || !value.height.is_finite() {
372 return None;
373 }
374 let line_count = value.html.match_indices("<div").count().max(1);
375 Some(ProbeCacheValue {
376 html: value.html,
377 width: value.width.max(0.0),
378 height: value.height.max(0.0),
379 line_count,
380 })
381 });
382
383 if let Some(probed_value) = probed.clone() {
384 if let Ok(mut render_cache) = self.render_cache.lock() {
385 render_cache
386 .entry(key.clone())
387 .or_insert_with(|| Some(probed_value.html.clone()));
388 }
389 }
390 if let Ok(mut cache) = self.sequence_probe_cache.lock() {
391 cache.insert(key, probed.clone());
392 }
393
394 probed
395 }
396}
397
398impl MathRenderer for NodeKatexMathRenderer {
399 fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String> {
400 if !text.contains("$$") {
401 return None;
402 }
403 self.render_cached(text, config)
404 }
405
406 fn measure_html_label(
407 &self,
408 text: &str,
409 config: &MermaidConfig,
410 style: &TextStyle,
411 max_width_px: Option<f64>,
412 wrap_mode: WrapMode,
413 ) -> Option<TextMetrics> {
414 if wrap_mode != WrapMode::HtmlLike || !text.contains("$$") {
415 return None;
416 }
417 let probed = self.probe_cached(text, config, style, max_width_px, wrap_mode)?;
418 Some(TextMetrics {
419 width: crate::text::round_to_1_64_px(probed.width),
420 height: crate::text::round_to_1_64_px(probed.height),
421 line_count: probed.line_count,
422 })
423 }
424
425 fn measure_sequence_html_label(
426 &self,
427 text: &str,
428 config: &MermaidConfig,
429 ) -> Option<TextMetrics> {
430 if !text.contains("$$") {
431 return None;
432 }
433 let probed = self.sequence_probe_cached(text, config)?;
434 Some(TextMetrics {
435 width: crate::text::round_to_1_64_px(probed.width),
436 height: crate::text::round_to_1_64_px(probed.height),
437 line_count: probed.line_count,
438 })
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Locates the repository's `tools/mermaid-cli` Node workspace.
    ///
    /// Returns `None` unless both `package.json` and an installed `node_modules` directory
    /// are present. The calling tests then return early and pass.
    fn mermaid_cli_dir() -> Option<PathBuf> {
        let dir = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("tools")
            .join("mermaid-cli");
        let has_toolchain =
            dir.join("package.json").is_file() && dir.join("node_modules").is_dir();
        has_toolchain.then_some(dir)
    }

    #[test]
    fn node_katex_math_renderer_smoke() {
        let Some(cli_dir) = mermaid_cli_dir() else {
            return;
        };
        let renderer = NodeKatexMathRenderer::new(cli_dir);
        let config = MermaidConfig::default();
        let style = TextStyle::default();
        let label = "$$x^2$$";

        // Rendering can still fail (e.g. no `node` executable); that is treated as a skip.
        let Some(html) = renderer.render_html_label(label, &config) else {
            return;
        };
        assert!(html.contains("katex"), "unexpected HTML: {html}");

        let measured =
            renderer.measure_html_label(label, &config, &style, Some(200.0), WrapMode::HtmlLike);
        let Some(metrics) = measured else {
            return;
        };
        assert!(metrics.width.is_finite() && metrics.width > 0.0);
        assert!(metrics.height.is_finite() && metrics.height > 0.0);
    }

    #[test]
    fn node_katex_math_renderer_measures_sanitized_flowchart_browser_shell() {
        let Some(cli_dir) = mermaid_cli_dir() else {
            return;
        };
        let renderer = NodeKatexMathRenderer::new(cli_dir);
        let config = MermaidConfig::default();
        let style = TextStyle::default();
        let measure = |label: &str| {
            renderer.measure_html_label(label, &config, &style, Some(200.0), WrapMode::HtmlLike)
        };

        let long_integral = "$$f(\\relax{x}) = \\int_{-\\infty}^\\infty \\hat{f}(\\xi)\\,e^{2 \\pi i \\xi x}\\,d\\xi$$";
        let Some(node_metrics) = measure(long_integral) else {
            return;
        };
        assert!(
            (180.0..=260.0).contains(&node_metrics.width),
            "node width = {}",
            node_metrics.width
        );
        assert!(
            (30.0..=70.0).contains(&node_metrics.height),
            "node height = {}",
            node_metrics.height
        );

        // A successful probe seeds the render cache, so rendering must now succeed.
        let html = renderer
            .render_html_label(long_integral, &config)
            .expect("expected rendered math HTML after successful probe");
        assert!(html.contains("<math"), "unexpected HTML: {html}");
        assert!(!html.contains("<semantics>"), "unsanitized HTML: {html}");

        let nested_delimiters = "$$\\Bigg(\\bigg(\\Big(\\big((\\frac{-b\\pm\\sqrt{b^2-4ac}}{2a})\\big)\\Big)\\bigg)\\Bigg)$$";
        let Some(edge_metrics) = measure(nested_delimiters) else {
            return;
        };
        assert!(
            (180.0..=320.0).contains(&edge_metrics.width),
            "edge width = {}",
            edge_metrics.width
        );
        assert!(
            (40.0..=100.0).contains(&edge_metrics.height),
            "edge height = {}",
            edge_metrics.height
        );
    }
}