1use crate::text::{TextMetrics, TextStyle, WrapMode};
11use merman_core::MermaidConfig;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::fmt::Write as _;
15use std::io::Write as _;
16use std::path::{Path, PathBuf};
17use std::process::{Command, Stdio};
18use std::sync::Mutex;
19
20pub trait MathRenderer: std::fmt::Debug {
27 fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String>;
34
35 fn measure_html_label(
40 &self,
41 _text: &str,
42 _config: &MermaidConfig,
43 _style: &TextStyle,
44 _max_width_px: Option<f64>,
45 _wrap_mode: WrapMode,
46 ) -> Option<TextMetrics> {
47 None
48 }
49}
50
51#[derive(Debug, Default, Clone, Copy)]
53pub struct NoopMathRenderer;
54
55impl MathRenderer for NoopMathRenderer {
56 fn render_html_label(&self, _text: &str, _config: &MermaidConfig) -> Option<String> {
57 None
58 }
59}
60
/// Identity of a render request: the normalised label text plus the math flags
/// that are forwarded to the Node script.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RenderCacheKey {
    /// Label text after doubled backslashes have been collapsed.
    text: String,
    /// Value of the `legacyMathML` config flag.
    legacy_mathml: bool,
    /// Value of the `forceLegacyMathML` config flag.
    force_legacy_mathml: bool,
}

/// Identity of a measurement request: the render identity plus the font and
/// wrapping inputs that influence the measured box.
///
/// Floating-point inputs are stored via `f64::to_bits` because `f64` implements
/// neither `Eq` nor `Hash`, both of which are required for a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ProbeCacheKey {
    render: RenderCacheKey,
    font_family: Option<String>,
    font_size_bits: u64,
    font_weight: Option<String>,
    max_width_bits: u64,
}

/// Result of a successful probe: the HTML the script produced and its measured box.
///
/// As built by `probe_cached`, `width`/`height` are finite and non-negative and
/// `line_count` is at least 1.
#[derive(Debug, Clone)]
struct ProbeCacheValue {
    html: String,
    width: f64,
    height: f64,
    line_count: usize,
}
84
85#[derive(Debug, Serialize)]
86struct NodeRenderRequest {
87 text: String,
88 config: NodeMathConfig,
89}
90
91#[derive(Debug, Serialize)]
92struct NodeProbeRequest {
93 text: String,
94 config: NodeMathConfig,
95 #[serde(rename = "styleCss")]
96 style_css: String,
97 #[serde(rename = "maxWidthPx")]
98 max_width_px: f64,
99}
100
101#[derive(Debug, Serialize)]
102struct NodeMathConfig {
103 #[serde(rename = "legacyMathML")]
104 legacy_mathml: bool,
105 #[serde(rename = "forceLegacyMathML")]
106 force_legacy_mathml: bool,
107}
108
109#[derive(Debug, Deserialize)]
110struct NodeRenderResponse {
111 html: String,
112}
113
114#[derive(Debug, Deserialize)]
115struct NodeProbeResponse {
116 html: String,
117 width: f64,
118 height: f64,
119}
120
121#[derive(Debug)]
131pub struct NodeKatexMathRenderer {
132 node_cwd: PathBuf,
133 node_command: PathBuf,
134 render_cache: Mutex<HashMap<RenderCacheKey, Option<String>>>,
135 probe_cache: Mutex<HashMap<ProbeCacheKey, Option<ProbeCacheValue>>>,
136}
137
138impl NodeKatexMathRenderer {
139 pub fn new(node_cwd: impl Into<PathBuf>) -> Self {
140 Self {
141 node_cwd: node_cwd.into(),
142 node_command: PathBuf::from("node"),
143 render_cache: Mutex::new(HashMap::new()),
144 probe_cache: Mutex::new(HashMap::new()),
145 }
146 }
147
148 pub fn with_node_command(mut self, node_command: impl Into<PathBuf>) -> Self {
149 self.node_command = node_command.into();
150 self
151 }
152
153 fn script_path() -> PathBuf {
154 Path::new(env!("CARGO_MANIFEST_DIR"))
155 .join("assets")
156 .join("katex_flowchart_probe.cjs")
157 }
158
159 fn normalized_text(text: &str) -> String {
160 text.replace("\\\\", "\\")
161 }
162
163 fn math_config(config: &MermaidConfig) -> NodeMathConfig {
164 let config_value = config.as_value();
165 let legacy_mathml = config_value
166 .get("legacyMathML")
167 .and_then(serde_json::Value::as_bool)
168 .unwrap_or(false);
169 let force_legacy_mathml = config_value
170 .get("forceLegacyMathML")
171 .and_then(serde_json::Value::as_bool)
172 .unwrap_or(false);
173 NodeMathConfig {
174 legacy_mathml,
175 force_legacy_mathml,
176 }
177 }
178
179 fn render_key(text: &str, config: &MermaidConfig) -> RenderCacheKey {
180 let config = Self::math_config(config);
181 RenderCacheKey {
182 text: Self::normalized_text(text),
183 legacy_mathml: config.legacy_mathml,
184 force_legacy_mathml: config.force_legacy_mathml,
185 }
186 }
187
188 fn style_css(style: &TextStyle) -> String {
189 let mut out = String::new();
190 let font_family = style
191 .font_family
192 .as_deref()
193 .unwrap_or("\"trebuchet ms\",verdana,arial,sans-serif");
194 let _ = write!(&mut out, "font-size: {}px;", style.font_size);
195 let _ = write!(&mut out, "font-family: {};", font_family);
196 if let Some(font_weight) = style.font_weight.as_deref() {
197 if !font_weight.trim().is_empty() {
198 let _ = write!(&mut out, "font-weight: {};", font_weight.trim());
199 }
200 }
201 out
202 }
203
204 fn run_node_request<T, R>(&self, mode: &str, payload: &T) -> Option<R>
205 where
206 T: Serialize,
207 R: for<'de> Deserialize<'de>,
208 {
209 if !self.node_cwd.join("package.json").is_file() {
210 return None;
211 }
212
213 let mut child = Command::new(&self.node_command)
214 .arg(Self::script_path())
215 .arg(mode)
216 .current_dir(&self.node_cwd)
217 .stdin(Stdio::piped())
218 .stdout(Stdio::piped())
219 .stderr(Stdio::piped())
220 .spawn()
221 .ok()?;
222
223 if let Some(mut stdin) = child.stdin.take() {
224 if serde_json::to_writer(&mut stdin, payload).is_err() {
225 return None;
226 }
227 let _ = stdin.flush();
228 }
229
230 let output = child.wait_with_output().ok()?;
231 if !output.status.success() {
232 return None;
233 }
234
235 serde_json::from_slice(&output.stdout).ok()
236 }
237
238 fn render_cached(&self, text: &str, config: &MermaidConfig) -> Option<String> {
239 let key = Self::render_key(text, config);
240 if let Some(cached) = self
241 .render_cache
242 .lock()
243 .ok()
244 .and_then(|cache| cache.get(&key).cloned())
245 {
246 return cached;
247 }
248
249 let response: Option<NodeRenderResponse> = self.run_node_request(
250 "render",
251 &NodeRenderRequest {
252 text: key.text.clone(),
253 config: NodeMathConfig {
254 legacy_mathml: key.legacy_mathml,
255 force_legacy_mathml: key.force_legacy_mathml,
256 },
257 },
258 );
259 let html = response.map(|value| value.html);
260
261 if let Ok(mut cache) = self.render_cache.lock() {
262 cache.insert(key, html.clone());
263 }
264
265 html
266 }
267
268 fn probe_cached(
269 &self,
270 text: &str,
271 config: &MermaidConfig,
272 style: &TextStyle,
273 max_width_px: Option<f64>,
274 _wrap_mode: WrapMode,
275 ) -> Option<ProbeCacheValue> {
276 let render = Self::render_key(text, config);
277 let max_width = max_width_px.unwrap_or(200.0).max(1.0);
278 let key = ProbeCacheKey {
279 render: render.clone(),
280 font_family: style.font_family.clone(),
281 font_size_bits: style.font_size.to_bits(),
282 font_weight: style.font_weight.clone(),
283 max_width_bits: max_width.to_bits(),
284 };
285 if let Some(cached) = self
286 .probe_cache
287 .lock()
288 .ok()
289 .and_then(|cache| cache.get(&key).cloned())
290 {
291 return cached;
292 }
293
294 let style_css = Self::style_css(style);
295 let response: Option<NodeProbeResponse> = self.run_node_request(
296 "probe",
297 &NodeProbeRequest {
298 text: render.text.clone(),
299 config: NodeMathConfig {
300 legacy_mathml: render.legacy_mathml,
301 force_legacy_mathml: render.force_legacy_mathml,
302 },
303 style_css,
304 max_width_px: max_width,
305 },
306 );
307 let probed = response.and_then(|value| {
308 if !value.width.is_finite() || !value.height.is_finite() {
309 return None;
310 }
311 let line_count = value.html.match_indices("<div").count().max(1);
312 Some(ProbeCacheValue {
313 html: value.html,
314 width: value.width.max(0.0),
315 height: value.height.max(0.0),
316 line_count,
317 })
318 });
319
320 if let Some(probed_value) = probed.clone() {
321 if let Ok(mut render_cache) = self.render_cache.lock() {
322 render_cache
323 .entry(render)
324 .or_insert_with(|| Some(probed_value.html.clone()));
325 }
326 }
327 if let Ok(mut cache) = self.probe_cache.lock() {
328 cache.insert(key, probed.clone());
329 }
330
331 probed
332 }
333}
334
335impl MathRenderer for NodeKatexMathRenderer {
336 fn render_html_label(&self, text: &str, config: &MermaidConfig) -> Option<String> {
337 if !text.contains("$$") {
338 return None;
339 }
340 self.render_cached(text, config)
341 }
342
343 fn measure_html_label(
344 &self,
345 text: &str,
346 config: &MermaidConfig,
347 style: &TextStyle,
348 max_width_px: Option<f64>,
349 wrap_mode: WrapMode,
350 ) -> Option<TextMetrics> {
351 if wrap_mode != WrapMode::HtmlLike || !text.contains("$$") {
352 return None;
353 }
354 let probed = self.probe_cached(text, config, style, max_width_px, wrap_mode)?;
355 Some(TextMetrics {
356 width: crate::text::round_to_1_64_px(probed.width),
357 height: crate::text::round_to_1_64_px(probed.height),
358 line_count: probed.line_count,
359 })
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the repo's `tools/mermaid-cli` directory, but only when both its
    /// `package.json` and `node_modules` exist. Otherwise the Node-backed tests skip.
    fn installed_mermaid_cli_dir() -> Option<PathBuf> {
        let dir = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("tools")
            .join("mermaid-cli");
        let installed = dir.join("package.json").is_file() && dir.join("node_modules").is_dir();
        if installed {
            Some(dir)
        } else {
            None
        }
    }

    #[test]
    fn node_katex_math_renderer_smoke() {
        let node_cwd = match installed_mermaid_cli_dir() {
            Some(dir) => dir,
            None => return,
        };

        let renderer = NodeKatexMathRenderer::new(node_cwd);
        let config = MermaidConfig::default();
        let style = TextStyle::default();
        let label = "$$x^2$$";

        let html = renderer
            .render_html_label(label, &config)
            .expect("node KaTeX renderer should produce HTML");
        assert!(html.contains("katex"), "unexpected HTML: {html}");

        let metrics = renderer
            .measure_html_label(label, &config, &style, Some(200.0), WrapMode::HtmlLike)
            .expect("node KaTeX renderer should produce metrics");
        assert!(metrics.width.is_finite() && metrics.width > 0.0);
        assert!(metrics.height.is_finite() && metrics.height > 0.0);
    }

    #[test]
    fn node_katex_math_renderer_matches_flowchart_browser_shell_metrics() {
        let node_cwd = match installed_mermaid_cli_dir() {
            Some(dir) => dir,
            None => return,
        };

        let renderer = NodeKatexMathRenderer::new(node_cwd);
        let config = MermaidConfig::default();
        let style = TextStyle::default();
        // Every label in this test is measured with a 200px HTML-like wrapping width.
        let measure = |label: &str| {
            renderer.measure_html_label(label, &config, &style, Some(200.0), WrapMode::HtmlLike)
        };

        let node_metrics = measure(
            "$$f(\\relax{x}) = \\int_{-\\infty}^\\infty \\hat{f}(\\xi)\\,e^{2 \\pi i \\xi x}\\,d\\xi$$",
        )
        .expect("node label metrics");
        assert!(
            (node_metrics.width - 195.140625).abs() < 1e-9,
            "node width = {}",
            node_metrics.width
        );
        assert!(
            (node_metrics.height - 27.53125).abs() < 1e-9,
            "node height = {}",
            node_metrics.height
        );

        let edge_metrics = measure(
            "$$\\Bigg(\\bigg(\\Big(\\big((\\frac{-b\\pm\\sqrt{b^2-4ac}}{2a})\\big)\\Big)\\bigg)\\Bigg)$$",
        )
        .expect("edge label metrics");
        assert!(
            (edge_metrics.width - 184.78125).abs() < 1e-9,
            "edge width = {}",
            edge_metrics.width
        );
        assert!(
            (edge_metrics.height - 41.53125).abs() < 1e-9,
            "edge height = {}",
            edge_metrics.height
        );
    }
}