Skip to main content

dmc_transform/builtin/
math.rs

//! LaTeX math (`$...$` / `$$...$$`) -> KaTeX HTML or MathML. See `transformers/math.md` for full docs.
2
3use crate::pipeline::Transformer;
4use crate::visit::{NodeAction, Visitor, walk_root};
5use dmc_diagnostic::metadata::SourceMeta;
6use dmc_diagnostic::{Code, DiagResult};
7use dmc_parser::ast::*;
8use duck_diagnostic::{DiagnosticEngine, diag};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11
12type MathCacheKey = (String, bool, crate::MathEngine);
13type MathCache = HashMap<MathCacheKey, String>;
14
/// Render `$...$` (inline) and `$$...$$` (display) LaTeX spans.
///
/// The output format follows the process-wide [`crate::MathEngine`]
/// chosen via [`Math::set_engine`]: KaTeX HTML+MathML by default, or
/// MathML via `pulldown-latex`.
///
/// Two entry points:
/// - [`Math::preprocess_source`] runs before the dmc lexer; rewrites raw
///   `$...$` / `$$...$$` to `<MathMl mathml="..."/>` JSX so the parser
///   never sees unescaped LaTeX. Required: `_` and `^` inside math would
///   otherwise be parsed as Markdown emphasis markers.
/// - The [`Transformer`] impl runs as a pipeline pass on an already-parsed
///   AST. Used by tests and any caller that builds a `Document` directly.
#[derive(Default, Debug)]
pub struct Math;
26
27impl Math {
28  /// Rewrite `$...$` / `$$...$$` in raw MDX source to `<MathMl/>` JSX.
29  /// Skips fenced code blocks, inline code spans, and existing JSX tags.
30  pub fn preprocess_source(source: &str) -> String {
31    let mut out = String::with_capacity(source.len());
32    let bytes = source.as_bytes();
33    let mut i = 0;
34    if let Some(end) = Self::skip_frontmatter(source, bytes) {
35      out.push_str(&source[..end]);
36      i = end;
37    }
38    while i < bytes.len() {
39      if let Some(end) = Self::skip_fenced_code(source, bytes, i) {
40        out.push_str(&source[i..end]);
41        i = end;
42        continue;
43      }
44      if let Some(end) = Self::skip_inline_code(source, bytes, i) {
45        out.push_str(&source[i..end]);
46        i = end;
47        continue;
48      }
49      if let Some(end) = Self::skip_jsx_tag(source, bytes, i) {
50        out.push_str(&source[i..end]);
51        i = end;
52        continue;
53      }
54      if bytes[i] == b'\\' && bytes.get(i + 1) == Some(&b'$') {
55        out.push_str("\\$");
56        i += 2;
57        continue;
58      }
59      if bytes[i] == b'$' {
60        let display = bytes.get(i + 1) == Some(&b'$');
61        let delim_len = if display { 2 } else { 1 };
62        let body_start = i + delim_len;
63        let close_off =
64          if display { source[body_start..].find("$$") } else { Self::find_inline_close(&source[body_start..]) };
65        if let Some(off) = close_off {
66          let inner = &source[body_start..body_start + off];
67          if !inner.trim().is_empty() {
68            let rendered = Self::render(inner, display);
69            out.push_str(&format!("<MathMl mathml=\"{}\"/>", Self::escape_jsx_attr(&rendered)));
70            i = body_start + off + delim_len;
71            continue;
72          }
73        }
74        out.push('$');
75        i += 1;
76        continue;
77      }
78      let ch_len = utf8_char_len(bytes[i]);
79      out.push_str(&source[i..i + ch_len]);
80      i += ch_len;
81    }
82    out
83  }
84
85  /// Render a LaTeX string. Engine is the active [`crate::MathEngine`]
86  /// (default KaTeX HTML; can be flipped to MathML via `pulldown-latex`).
87  /// Cached by `(latex, display, engine)` so repeated math in a doc hits
88  /// the renderer once. On parse failure returns a
89  /// `<span class="math-error">` wrapper around the original LaTeX.
90  pub fn render(latex: &str, display: bool) -> String {
91    let engine = Self::active_engine();
92    let cache = Self::cache();
93    let key = (latex.to_string(), display, engine);
94    if let Some(hit) = cache.lock().expect("math cache lock").get(&key) {
95      return hit.clone();
96    }
97    let html = match engine {
98      crate::MathEngine::Katex => Self::render_katex(latex, display),
99      crate::MathEngine::Mathml => Self::render_mathml(latex, display),
100    };
101    cache.lock().expect("math cache lock").insert(key, html.clone());
102    html
103  }
104
105  fn render_katex(latex: &str, display: bool) -> String {
106    let opts_result = if display { Self::display_opts() } else { Self::inline_opts() };
107    let opts = match opts_result {
108      Ok(o) => o,
109      // KaTeX builder failure -> fall back to the error placeholder
110      // so the build still completes. The diagnostic itself is
111      // discarded here because `render_katex` has no
112      // `&mut DiagnosticEngine` handle; callers that need to capture
113      // it should invoke `inline_opts()` / `display_opts()` directly
114      // and propagate the `Diagnostic<Code>`.
115      Err(_) => return Self::error_span(latex, display),
116    };
117    match katex::render_with_opts(latex, opts) {
118      Ok(html) => html,
119      Err(_) => Self::error_span(latex, display),
120    }
121  }
122
123  fn render_mathml(latex: &str, display: bool) -> String {
124    use pulldown_latex::config::DisplayMode;
125    use pulldown_latex::{Parser, RenderConfig, Storage, mathml::push_mathml};
126    let storage = Storage::new();
127    let parser = Parser::new(latex, &storage);
128    let cfg = RenderConfig {
129      display_mode: if display { DisplayMode::Block } else { DisplayMode::Inline },
130      ..Default::default()
131    };
132    let mut out = String::new();
133    match push_mathml(&mut out, parser, cfg) {
134      Ok(()) => out,
135      Err(_) => Self::error_span(latex, display),
136    }
137  }
138
139  fn error_span(latex: &str, display: bool) -> String {
140    format!(
141      "<span class=\"math-error\">{}{}{}</span>",
142      if display { "$$" } else { "$" },
143      latex,
144      if display { "$$" } else { "$" }
145    )
146  }
147
148  /// Process-wide active engine. Set once via [`Math::set_engine`]
149  /// (pipeline does this from `PipelineConfig::math_engine`); defaults
150  /// to KaTeX. Stored as a static so [`Self::render`] does not need a
151  /// per-call engine argument (keeps the source preprocessor signature
152  /// engine-agnostic).
153  pub fn set_engine(engine: crate::MathEngine) {
154    Self::engine_slot().store(engine_to_u8(engine), std::sync::atomic::Ordering::Release);
155  }
156
157  fn active_engine() -> crate::MathEngine {
158    u8_to_engine(Self::engine_slot().load(std::sync::atomic::Ordering::Acquire))
159  }
160
161  fn engine_slot() -> &'static std::sync::atomic::AtomicU8 {
162    static S: OnceLock<std::sync::atomic::AtomicU8> = OnceLock::new();
163    S.get_or_init(|| std::sync::atomic::AtomicU8::new(engine_to_u8(crate::MathEngine::default())))
164  }
165
166  fn cache() -> &'static Mutex<MathCache> {
167    static C: OnceLock<Mutex<MathCache>> = OnceLock::new();
168    C.get_or_init(|| Mutex::new(HashMap::new()))
169  }
170
171  /// Load a previously persisted math cache from `path`. Missing or
172  /// corrupt files yield `Ok(())` (empty cache). Other IO errors
173  /// propagate as `IoRead`.
174  #[allow(clippy::result_large_err)]
175  pub fn load_cache(path: &std::path::Path) -> DiagResult {
176    let s = match std::fs::read_to_string(path) {
177      Ok(s) => s,
178      Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
179      Err(e) => return Err(diag!(Code::IoRead, format!("math cache read at {}: {}", path.display(), e))),
180    };
181
182    let rows = match serde_json::from_str::<Vec<(String, bool, u8, String)>>(&s) {
183      Ok(r) => r,
184      Err(_) => return Ok(()),
185    };
186
187    let mut cache = Self::cache().lock().map_err(|e| diag!(Code::LockPoisoned, format!("math cache lock: {}", e)))?;
188
189    for (latex, display, eng, html) in rows {
190      cache.entry((latex, display, u8_to_engine(eng))).or_insert(html);
191    }
192    Ok(())
193  }
194
195  /// Persist the in-memory math cache to `path`. Best effort; errors
196  /// are swallowed.
197  #[allow(clippy::result_large_err)]
198  pub fn save_cache(path: &std::path::Path) -> DiagResult {
199    let cache = Self::cache().lock().expect("math cache lock");
200    let rows: Vec<(String, bool, u8, String)> = cache
201      .iter()
202      .map(|((latex, display, eng), html)| (latex.clone(), *display, engine_to_u8(*eng), html.clone()))
203      .collect();
204
205    let json =
206      serde_json::to_string(&rows).map_err(|e| diag!(Code::JsonSerialize, format!("math cache serialise: {}", e)))?;
207
208    if let Some(parent) = path.parent() {
209      std::fs::create_dir_all(parent)
210        .map_err(|e| diag!(Code::IoCreateDir, format!("math cache dir at {}: {}", parent.display(), e)))?;
211    }
212
213    std::fs::write(path, json)
214      .map_err(|e| diag!(Code::IoWrite, format!("math cache write at {}: {}", path.display(), e)))?;
215    Ok(())
216  }
217
218  #[allow(clippy::result_large_err)]
219  fn display_opts() -> DiagResult<&'static katex::Opts> {
220    static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
221    let cached = O.get_or_init(|| {
222      katex::Opts::builder()
223        .display_mode(true)
224        .output_type(katex::OutputType::HtmlAndMathml)
225        .build()
226        .map_err(|e| e.to_string())
227    });
228    cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
229  }
230
231  /// Build the KaTeX renderer once and cache. Inputs are all
232  /// hard-coded constants, so any builder failure here is a packaging
233  /// bug (e.g. a busted katex feature combo) - we surface it as
234  /// `Code::KatexOpts` (warning, not fatal) and let the caller decide
235  /// what to do with the unrenderable span.
236  #[allow(clippy::result_large_err)]
237  fn inline_opts() -> DiagResult<&'static katex::Opts> {
238    static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
239    let cached = O.get_or_init(|| {
240      katex::Opts::builder()
241        .display_mode(false)
242        .output_type(katex::OutputType::HtmlAndMathml)
243        .build()
244        .map_err(|e| e.to_string())
245    });
246    cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
247  }
248
249  /// Render LaTeX as a self-closing `<MathMl/>` JsxSelfClosing node.
250  pub fn render_node(latex: &str, display: bool, span: &duck_diagnostic::Span) -> Node {
251    let mathml = Self::render(latex, display);
252    Node::JsxSelfClosing(JsxSelfClosing {
253      name: "MathMl".into(),
254      attrs: vec![JsxAttr { name: "mathml".into(), value: JsxAttrValue::String(mathml), span: span.clone() }],
255      span: span.clone(),
256    })
257  }
258
259  // helpers
260
261  /// Skip YAML (`---`) or TOML (`+++`) frontmatter at byte 0 so `$` runs
262  /// inside `description`, etc. never get pair-matched as math spans.
263  /// Returns the byte offset of the first content byte after the closing
264  /// fence's newline, or `None` if no frontmatter is present.
265  fn skip_frontmatter(source: &str, bytes: &[u8]) -> Option<usize> {
266    let fence = if bytes.starts_with(b"---\n") || bytes.starts_with(b"---\r\n") {
267      "---"
268    } else if bytes.starts_with(b"+++\n") || bytes.starts_with(b"+++\r\n") {
269      "+++"
270    } else {
271      return None;
272    };
273    let body_start = if bytes[3] == b'\r' { 5 } else { 4 };
274    let rest = &source[body_start..];
275    // Closing fence must sit at the start of a line. Scan for `\n<fence>`
276    // and accept either bare-EOL or trailing-newline termination.
277    let mut search = 0usize;
278    while let Some(rel) = rest[search..].find(fence) {
279      let abs = search + rel;
280      let at_line_start = abs == 0 || rest.as_bytes()[abs - 1] == b'\n';
281      let after = abs + fence.len();
282      let terminates = after == rest.len() || rest.as_bytes()[after] == b'\n' || rest.as_bytes()[after] == b'\r';
283      if at_line_start && terminates {
284        let mut end = body_start + after;
285        if end < bytes.len() && bytes[end] == b'\r' {
286          end += 1;
287        }
288        if end < bytes.len() && bytes[end] == b'\n' {
289          end += 1;
290        }
291        return Some(end);
292      }
293      search = abs + fence.len();
294    }
295    None
296  }
297
298  fn skip_fenced_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
299    if bytes[i] != b'`' || bytes.get(i + 1) != Some(&b'`') || bytes.get(i + 2) != Some(&b'`') {
300      return None;
301    }
302    let end = source[i + 3..].find("```").map(|p| i + 3 + p + 3).unwrap_or(bytes.len());
303    Some(end)
304  }
305
306  fn skip_inline_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
307    if bytes[i] != b'`' {
308      return None;
309    }
310    let p = source[i + 1..].find('`')?;
311    Some(i + 1 + p + 1)
312  }
313
314  fn skip_jsx_tag(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
315    if bytes[i] != b'<' {
316      return None;
317    }
318    let p = source[i + 1..].find('>')?;
319    Some(i + 1 + p + 1)
320  }
321
322  fn find_inline_close(inline: &str) -> Option<usize> {
323    let mut search = 0usize;
324    while search < inline.len() {
325      let rel = inline[search..].find(['$', '\n'])?;
326      let abs = search + rel;
327      if inline.as_bytes()[abs] == b'\n' {
328        return None;
329      }
330      if abs > 0 && inline.as_bytes()[abs - 1] == b'\\' {
331        search = abs + 1;
332        continue;
333      }
334      return Some(abs);
335    }
336    None
337  }
338
339  /// Escape `"` and `&` so MathML survives JSX attribute parsing.
340  /// Reversed by the codegen `MathMl` raw-HTML paster.
341  fn escape_jsx_attr(s: &str) -> String {
342    s.replace('&', "&amp;").replace('"', "&quot;")
343  }
344}
345
346impl Transformer for Math {
347  fn name(&self) -> &str {
348    "math"
349  }
350  fn transform(&self, doc: &mut Document, _meta: &SourceMeta, _engine: &mut DiagnosticEngine<Code>) {
351    let mut v = Apply;
352    walk_root(&mut doc.children, &mut v);
353  }
354}
355
356struct Apply;
357
358impl Visitor for Apply {
359  fn visit_node(&mut self, node: &mut Node) -> NodeAction {
360    if let Node::Paragraph(p) = node
361      && let [Node::Text(t)] = p.children.as_slice()
362      && let Some(latex) = Math::unwrap_block(t.value.trim())
363    {
364      let span = t.span.clone();
365      return NodeAction::Replace(vec![Math::render_node(latex, true, &span)]);
366    }
367    let Node::Text(t) = node else { return NodeAction::Keep };
368    let Some(replacement) = Math::expand_inline(&t.value, &t.span) else { return NodeAction::Keep };
369    NodeAction::Replace(replacement)
370  }
371}
372
373impl Math {
374  fn unwrap_block(s: &str) -> Option<&str> {
375    let s = s.trim();
376    let inner = s.strip_prefix("$$")?.strip_suffix("$$")?;
377    Some(inner.trim())
378  }
379
380  fn expand_inline(text: &str, span: &duck_diagnostic::Span) -> Option<Vec<Node>> {
381    if !text.contains('$') {
382      return None;
383    }
384    let mut out: Vec<Node> = Vec::new();
385    let mut buf = String::new();
386    let bytes = text.as_bytes();
387    let mut i = 0;
388    let mut found_any = false;
389
390    while i < bytes.len() {
391      let c = bytes[i];
392      if c == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'$' {
393        buf.push('$');
394        i += 2;
395        continue;
396      }
397      if c != b'$' {
398        let ch_len = utf8_char_len(bytes[i]);
399        buf.push_str(&text[i..i + ch_len]);
400        i += ch_len;
401        continue;
402      }
403      let (delim, display) = if i + 1 < bytes.len() && bytes[i + 1] == b'$' { ("$$", true) } else { ("$", false) };
404      let inner_start = i + delim.len();
405      let Some(close_off) = Self::find_unescaped(&text[inner_start..], delim) else {
406        buf.push('$');
407        i += 1;
408        continue;
409      };
410      let inner = &text[inner_start..inner_start + close_off];
411      if !buf.is_empty() {
412        out.push(Node::Text(Text { value: std::mem::take(&mut buf), span: span.clone() }));
413      }
414      out.push(Self::render_node(inner, display, span));
415      i = inner_start + close_off + delim.len();
416      found_any = true;
417    }
418
419    if !found_any {
420      return None;
421    }
422    if !buf.is_empty() {
423      out.push(Node::Text(Text { value: buf, span: span.clone() }));
424    }
425    Some(out)
426  }
427
428  fn find_unescaped(haystack: &str, delim: &str) -> Option<usize> {
429    let mut search_from = 0;
430    while search_from < haystack.len() {
431      let off = haystack[search_from..].find(delim)?;
432      let abs = search_from + off;
433      if abs > 0 && haystack.as_bytes()[abs - 1] == b'\\' {
434        search_from = abs + delim.len();
435        continue;
436      }
437      return Some(abs);
438    }
439    None
440  }
441}
442
443fn engine_to_u8(e: crate::MathEngine) -> u8 {
444  match e {
445    crate::MathEngine::Katex => 0,
446    crate::MathEngine::Mathml => 1,
447  }
448}
449
450fn u8_to_engine(b: u8) -> crate::MathEngine {
451  match b {
452    1 => crate::MathEngine::Mathml,
453    _ => crate::MathEngine::Katex,
454  }
455}
456
/// Byte length of the UTF-8 sequence whose lead byte is `b`, judged from
/// the lead byte alone.
///
/// Only meaningful on a char boundary. A continuation byte
/// (`0x80..=0xBF`) is not rejected and reports 2.
fn utf8_char_len(b: u8) -> usize {
  match b {
    0x00..=0x7F => 1,
    0x80..=0xDF => 2,
    0xE0..=0xEF => 3,
    _ => 4,
  }
}
468
#[cfg(test)]
mod tests {
  use super::*;
  use duck_diagnostic::Span;
  use std::sync::Arc;

  /// Placeholder span attached to every node the tests build.
  fn s() -> Span {
    Span { file: Arc::from("<t>"), line: 1, column: 1, length: 0 }
  }

  /// True when `node` is a `Text` node whose value equals `want`.
  fn is_text(node: &Node, want: &str) -> bool {
    matches!(node, Node::Text(t) if t.value == want)
  }

  #[test]
  fn passthrough_when_no_dollars() {
    let result = Math::expand_inline("nothing here", &s());
    assert!(result.is_none());
  }

  #[test]
  fn inline_math_replaces_one_span() {
    let nodes = Math::expand_inline("a $x+1$ b", &s()).expect("matched");
    // Expected shape: Text("a ") + JsxSelfClosing(MathMl) + Text(" b").
    assert_eq!(nodes.len(), 3);
    assert!(is_text(&nodes[0], "a "));
    assert!(matches!(&nodes[1], Node::JsxSelfClosing(e) if e.name == "MathMl"));
    assert!(is_text(&nodes[2], " b"));
  }

  #[test]
  fn escaped_dollar_is_literal() {
    let kept_literal = match Math::expand_inline(r"price \$5", &s()) {
      None => true,
      Some(nodes) => matches!(&nodes[0], Node::Text(t) if t.value.contains("$5")),
    };
    assert!(kept_literal);
  }

  #[test]
  fn unmatched_dollar_left_alone() {
    let result = Math::expand_inline("a $ stray", &s());
    assert!(result.is_none());
  }

  #[test]
  fn block_math_unwraps() {
    let span = s();
    let text = Node::Text(Text { value: "$$ x = y $$".into(), span: span.clone() });
    let para = Node::Paragraph(Paragraph { children: vec![text], span: span.clone() });
    let mut doc = Document { children: vec![para], span };

    walk_root(&mut doc.children, &mut Apply);

    assert_eq!(doc.children.len(), 1);
    let Node::JsxSelfClosing(e) = &doc.children[0] else {
      panic!("expected MathMl element, got {:?}", doc.children[0]);
    };
    assert_eq!(e.name, "MathMl");
    let mathml = e.attrs.iter().find(|a| a.name == "mathml").unwrap();
    assert!(
      matches!(&mathml.value, JsxAttrValue::String(s) if s.contains("<math") && s.contains("display=\"block\""))
    );
  }
}