Skip to main content

dmc_transform/builtin/
math.rs

1//! LaTeX -> KaTeX/MathML. See `transformers/math.md` for full docs.
2
3use crate::pipeline::Transformer;
4use crate::visit::{NodeAction, Visitor, walk_root};
5use dmc_diagnostic::metadata::SourceMeta;
6use dmc_diagnostic::{Code, DiagResult};
7use dmc_parser::ast::*;
8use duck_diagnostic::{DiagnosticEngine, diag};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11
12type MathCacheKey = (String, bool, crate::MathEngine);
13type MathCache = HashMap<MathCacheKey, String>;
14
/// Math transformer: turns `$...$` and `$$...$$` spans into `<MathMl/>`
/// JSX nodes that carry the rendered markup in a `mathml` attribute.
///
/// There are two ways in:
/// - [`Math::preprocess_source`] works on raw source text before the dmc
///   lexer runs, replacing `$...$` / `$$...$$` with
///   `<MathMl mathml="..."/>` so the parser never meets bare LaTeX. This
///   matters because `_` and `^` inside math would otherwise be read as
///   Markdown emphasis markers.
/// - The [`Transformer`] impl runs as a pipeline pass over an
///   already-parsed AST. Tests use it, as can any caller that builds a
///   Document directly.
#[derive(Debug, Default)]
pub struct Math;
26
27impl Math {
28  /// Rewrite `$...$` / `$$...$$` in raw MDX source to `<MathMl/>` JSX.
29  /// Skips fenced code blocks, inline code spans, and existing JSX tags.
30  pub fn preprocess_source(source: &str) -> String {
31    let mut out = String::with_capacity(source.len());
32    let bytes = source.as_bytes();
33    let mut i = 0;
34    while i < bytes.len() {
35      if let Some(end) = Self::skip_fenced_code(source, bytes, i) {
36        out.push_str(&source[i..end]);
37        i = end;
38        continue;
39      }
40      if let Some(end) = Self::skip_inline_code(source, bytes, i) {
41        out.push_str(&source[i..end]);
42        i = end;
43        continue;
44      }
45      if let Some(end) = Self::skip_jsx_tag(source, bytes, i) {
46        out.push_str(&source[i..end]);
47        i = end;
48        continue;
49      }
50      if bytes[i] == b'\\' && bytes.get(i + 1) == Some(&b'$') {
51        out.push_str("\\$");
52        i += 2;
53        continue;
54      }
55      if bytes[i] == b'$' {
56        let display = bytes.get(i + 1) == Some(&b'$');
57        let delim_len = if display { 2 } else { 1 };
58        let body_start = i + delim_len;
59        let close_off =
60          if display { source[body_start..].find("$$") } else { Self::find_inline_close(&source[body_start..]) };
61        if let Some(off) = close_off {
62          let inner = &source[body_start..body_start + off];
63          if !inner.trim().is_empty() {
64            let rendered = Self::render(inner, display);
65            out.push_str(&format!("<MathMl mathml=\"{}\"/>", Self::escape_jsx_attr(&rendered)));
66            i = body_start + off + delim_len;
67            continue;
68          }
69        }
70        out.push('$');
71        i += 1;
72        continue;
73      }
74      let ch_len = utf8_char_len(bytes[i]);
75      out.push_str(&source[i..i + ch_len]);
76      i += ch_len;
77    }
78    out
79  }
80
81  /// Render a LaTeX string. Engine is the active [`crate::MathEngine`]
82  /// (default KaTeX HTML; can be flipped to MathML via `pulldown-latex`).
83  /// Cached by `(latex, display, engine)` so repeated math in a doc hits
84  /// the renderer once. On parse failure returns a
85  /// `<span class="math-error">` wrapper around the original LaTeX.
86  pub fn render(latex: &str, display: bool) -> String {
87    let engine = Self::active_engine();
88    let cache = Self::cache();
89    let key = (latex.to_string(), display, engine);
90    if let Some(hit) = cache.lock().expect("math cache lock").get(&key) {
91      return hit.clone();
92    }
93    let html = match engine {
94      crate::MathEngine::Katex => Self::render_katex(latex, display),
95      crate::MathEngine::Mathml => Self::render_mathml(latex, display),
96    };
97    cache.lock().expect("math cache lock").insert(key, html.clone());
98    html
99  }
100
101  fn render_katex(latex: &str, display: bool) -> String {
102    let opts_result = if display { Self::display_opts() } else { Self::inline_opts() };
103    let opts = match opts_result {
104      Ok(o) => o,
105      // KaTeX builder failure -> fall back to the error placeholder
106      // so the build still completes. The diagnostic itself is
107      // discarded here because `render_katex` has no
108      // `&mut DiagnosticEngine` handle; callers that need to capture
109      // it should invoke `inline_opts()` / `display_opts()` directly
110      // and propagate the `Diagnostic<Code>`.
111      Err(_) => return Self::error_span(latex, display),
112    };
113    match katex::render_with_opts(latex, opts) {
114      Ok(html) => html,
115      Err(_) => Self::error_span(latex, display),
116    }
117  }
118
119  fn render_mathml(latex: &str, display: bool) -> String {
120    use pulldown_latex::config::DisplayMode;
121    use pulldown_latex::{Parser, RenderConfig, Storage, mathml::push_mathml};
122    let storage = Storage::new();
123    let parser = Parser::new(latex, &storage);
124    let cfg = RenderConfig {
125      display_mode: if display { DisplayMode::Block } else { DisplayMode::Inline },
126      ..Default::default()
127    };
128    let mut out = String::new();
129    match push_mathml(&mut out, parser, cfg) {
130      Ok(()) => out,
131      Err(_) => Self::error_span(latex, display),
132    }
133  }
134
135  fn error_span(latex: &str, display: bool) -> String {
136    format!(
137      "<span class=\"math-error\">{}{}{}</span>",
138      if display { "$$" } else { "$" },
139      latex,
140      if display { "$$" } else { "$" }
141    )
142  }
143
144  /// Process-wide active engine. Set once via [`Math::set_engine`]
145  /// (pipeline does this from `PipelineConfig::math_engine`); defaults
146  /// to KaTeX. Stored as a static so [`Self::render`] does not need a
147  /// per-call engine argument (keeps the source preprocessor signature
148  /// engine-agnostic).
149  pub fn set_engine(engine: crate::MathEngine) {
150    Self::engine_slot().store(engine_to_u8(engine), std::sync::atomic::Ordering::Release);
151  }
152
153  fn active_engine() -> crate::MathEngine {
154    u8_to_engine(Self::engine_slot().load(std::sync::atomic::Ordering::Acquire))
155  }
156
157  fn engine_slot() -> &'static std::sync::atomic::AtomicU8 {
158    static S: OnceLock<std::sync::atomic::AtomicU8> = OnceLock::new();
159    S.get_or_init(|| std::sync::atomic::AtomicU8::new(engine_to_u8(crate::MathEngine::default())))
160  }
161
162  fn cache() -> &'static Mutex<MathCache> {
163    static C: OnceLock<Mutex<MathCache>> = OnceLock::new();
164    C.get_or_init(|| Mutex::new(HashMap::new()))
165  }
166
167  /// Load a previously persisted math cache from `path`. Missing or
168  /// corrupt files yield `Ok(())` (empty cache). Other IO errors
169  /// propagate as `IoRead`.
170  #[allow(clippy::result_large_err)]
171  pub fn load_cache(path: &std::path::Path) -> DiagResult {
172    let s = match std::fs::read_to_string(path) {
173      Ok(s) => s,
174      Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
175      Err(e) => return Err(diag!(Code::IoRead, format!("math cache read at {}: {}", path.display(), e))),
176    };
177
178    let rows = match serde_json::from_str::<Vec<(String, bool, u8, String)>>(&s) {
179      Ok(r) => r,
180      Err(_) => return Ok(()),
181    };
182
183    let mut cache = Self::cache().lock().map_err(|e| diag!(Code::LockPoisoned, format!("math cache lock: {}", e)))?;
184
185    for (latex, display, eng, html) in rows {
186      cache.entry((latex, display, u8_to_engine(eng))).or_insert(html);
187    }
188    Ok(())
189  }
190
191  /// Persist the in-memory math cache to `path`. Best effort; errors
192  /// are swallowed.
193  #[allow(clippy::result_large_err)]
194  pub fn save_cache(path: &std::path::Path) -> DiagResult {
195    let cache = Self::cache().lock().expect("math cache lock");
196    let rows: Vec<(String, bool, u8, String)> = cache
197      .iter()
198      .map(|((latex, display, eng), html)| (latex.clone(), *display, engine_to_u8(*eng), html.clone()))
199      .collect();
200
201    let json =
202      serde_json::to_string(&rows).map_err(|e| diag!(Code::JsonSerialize, format!("math cache serialise: {}", e)))?;
203
204    if let Some(parent) = path.parent() {
205      std::fs::create_dir_all(parent)
206        .map_err(|e| diag!(Code::IoCreateDir, format!("math cache dir at {}: {}", parent.display(), e)))?;
207    }
208
209    std::fs::write(path, json)
210      .map_err(|e| diag!(Code::IoWrite, format!("math cache write at {}: {}", path.display(), e)))?;
211    Ok(())
212  }
213
214  #[allow(clippy::result_large_err)]
215  fn display_opts() -> DiagResult<&'static katex::Opts> {
216    static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
217    let cached = O.get_or_init(|| {
218      katex::Opts::builder()
219        .display_mode(true)
220        .output_type(katex::OutputType::HtmlAndMathml)
221        .build()
222        .map_err(|e| e.to_string())
223    });
224    cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
225  }
226
227  /// Build the KaTeX renderer once and cache. Inputs are all
228  /// hard-coded constants, so any builder failure here is a packaging
229  /// bug (e.g. a busted katex feature combo) - we surface it as
230  /// `Code::KatexOpts` (warning, not fatal) and let the caller decide
231  /// what to do with the unrenderable span.
232  #[allow(clippy::result_large_err)]
233  fn inline_opts() -> DiagResult<&'static katex::Opts> {
234    static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
235    let cached = O.get_or_init(|| {
236      katex::Opts::builder()
237        .display_mode(false)
238        .output_type(katex::OutputType::HtmlAndMathml)
239        .build()
240        .map_err(|e| e.to_string())
241    });
242    cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
243  }
244
245  /// Render LaTeX as a self-closing `<MathMl/>` JsxSelfClosing node.
246  pub fn render_node(latex: &str, display: bool, span: &duck_diagnostic::Span) -> Node {
247    let mathml = Self::render(latex, display);
248    Node::JsxSelfClosing(JsxSelfClosing {
249      name: "MathMl".into(),
250      attrs: vec![JsxAttr { name: "mathml".into(), value: JsxAttrValue::String(mathml), span: span.clone() }],
251      span: span.clone(),
252    })
253  }
254
255  // helpers
256
257  fn skip_fenced_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
258    if bytes[i] != b'`' || bytes.get(i + 1) != Some(&b'`') || bytes.get(i + 2) != Some(&b'`') {
259      return None;
260    }
261    let end = source[i + 3..].find("```").map(|p| i + 3 + p + 3).unwrap_or(bytes.len());
262    Some(end)
263  }
264
265  fn skip_inline_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
266    if bytes[i] != b'`' {
267      return None;
268    }
269    let p = source[i + 1..].find('`')?;
270    Some(i + 1 + p + 1)
271  }
272
273  fn skip_jsx_tag(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
274    if bytes[i] != b'<' {
275      return None;
276    }
277    let p = source[i + 1..].find('>')?;
278    Some(i + 1 + p + 1)
279  }
280
281  fn find_inline_close(inline: &str) -> Option<usize> {
282    let mut search = 0usize;
283    while search < inline.len() {
284      let rel = inline[search..].find(['$', '\n'])?;
285      let abs = search + rel;
286      if inline.as_bytes()[abs] == b'\n' {
287        return None;
288      }
289      if abs > 0 && inline.as_bytes()[abs - 1] == b'\\' {
290        search = abs + 1;
291        continue;
292      }
293      return Some(abs);
294    }
295    None
296  }
297
298  /// Escape `"` and `&` so MathML survives JSX attribute parsing.
299  /// Reversed by the codegen `MathMl` raw-HTML paster.
300  fn escape_jsx_attr(s: &str) -> String {
301    s.replace('&', "&amp;").replace('"', "&quot;")
302  }
303}
304
305impl Transformer for Math {
306  fn name(&self) -> &str {
307    "math"
308  }
309  fn transform(&self, doc: &mut Document, _meta: &SourceMeta, _engine: &mut DiagnosticEngine<Code>) {
310    let mut v = Apply;
311    walk_root(&mut doc.children, &mut v);
312  }
313}
314
315struct Apply;
316
317impl Visitor for Apply {
318  fn visit_node(&mut self, node: &mut Node) -> NodeAction {
319    if let Node::Paragraph(p) = node
320      && let [Node::Text(t)] = p.children.as_slice()
321      && let Some(latex) = Math::unwrap_block(t.value.trim())
322    {
323      let span = t.span.clone();
324      return NodeAction::Replace(vec![Math::render_node(latex, true, &span)]);
325    }
326    let Node::Text(t) = node else { return NodeAction::Keep };
327    let Some(replacement) = Math::expand_inline(&t.value, &t.span) else { return NodeAction::Keep };
328    NodeAction::Replace(replacement)
329  }
330}
331
332impl Math {
333  fn unwrap_block(s: &str) -> Option<&str> {
334    let s = s.trim();
335    let inner = s.strip_prefix("$$")?.strip_suffix("$$")?;
336    Some(inner.trim())
337  }
338
339  fn expand_inline(text: &str, span: &duck_diagnostic::Span) -> Option<Vec<Node>> {
340    if !text.contains('$') {
341      return None;
342    }
343    let mut out: Vec<Node> = Vec::new();
344    let mut buf = String::new();
345    let bytes = text.as_bytes();
346    let mut i = 0;
347    let mut found_any = false;
348
349    while i < bytes.len() {
350      let c = bytes[i];
351      if c == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'$' {
352        buf.push('$');
353        i += 2;
354        continue;
355      }
356      if c != b'$' {
357        let ch_len = utf8_char_len(bytes[i]);
358        buf.push_str(&text[i..i + ch_len]);
359        i += ch_len;
360        continue;
361      }
362      let (delim, display) = if i + 1 < bytes.len() && bytes[i + 1] == b'$' { ("$$", true) } else { ("$", false) };
363      let inner_start = i + delim.len();
364      let Some(close_off) = Self::find_unescaped(&text[inner_start..], delim) else {
365        buf.push('$');
366        i += 1;
367        continue;
368      };
369      let inner = &text[inner_start..inner_start + close_off];
370      if !buf.is_empty() {
371        out.push(Node::Text(Text { value: std::mem::take(&mut buf), span: span.clone() }));
372      }
373      out.push(Self::render_node(inner, display, span));
374      i = inner_start + close_off + delim.len();
375      found_any = true;
376    }
377
378    if !found_any {
379      return None;
380    }
381    if !buf.is_empty() {
382      out.push(Node::Text(Text { value: buf, span: span.clone() }));
383    }
384    Some(out)
385  }
386
387  fn find_unescaped(haystack: &str, delim: &str) -> Option<usize> {
388    let mut search_from = 0;
389    while search_from < haystack.len() {
390      let off = haystack[search_from..].find(delim)?;
391      let abs = search_from + off;
392      if abs > 0 && haystack.as_bytes()[abs - 1] == b'\\' {
393        search_from = abs + delim.len();
394        continue;
395      }
396      return Some(abs);
397    }
398    None
399  }
400}
401
402fn engine_to_u8(e: crate::MathEngine) -> u8 {
403  match e {
404    crate::MathEngine::Katex => 0,
405    crate::MathEngine::Mathml => 1,
406  }
407}
408
409fn u8_to_engine(b: u8) -> crate::MathEngine {
410  match b {
411    1 => crate::MathEngine::Mathml,
412    _ => crate::MathEngine::Katex,
413  }
414}
415
/// Byte length of the UTF-8 character whose first byte is `b`.
///
/// Only meaningful for a leading byte. A continuation byte
/// (`0x80..=0xBF`) falls into the two-byte bucket.
fn utf8_char_len(b: u8) -> usize {
  match b {
    0x00..=0x7F => 1,
    0x80..=0xDF => 2,
    0xE0..=0xEF => 3,
    0xF0..=0xFF => 4,
  }
}
427
#[cfg(test)]
mod tests {
  use super::*;
  use duck_diagnostic::Span;
  use std::sync::Arc;

  /// Placeholder span shared by every node built in these tests.
  fn dummy_span() -> Span {
    Span { file: Arc::from("<t>"), line: 1, column: 1, length: 0 }
  }

  #[test]
  fn passthrough_when_no_dollars() {
    let result = Math::expand_inline("nothing here", &dummy_span());
    assert!(result.is_none());
  }

  #[test]
  fn inline_math_replaces_one_span() {
    let nodes = Math::expand_inline("a $x+1$ b", &dummy_span()).expect("matched");
    // Expected shape: Text("a "), JsxSelfClosing(MathMl), Text(" b").
    assert_eq!(nodes.len(), 3);
    assert!(matches!(&nodes[0], Node::Text(t) if t.value == "a "));
    assert!(matches!(&nodes[1], Node::JsxSelfClosing(e) if e.name == "MathMl"));
    assert!(matches!(&nodes[2], Node::Text(t) if t.value == " b"));
  }

  #[test]
  fn escaped_dollar_is_literal() {
    // Either nothing is replaced, or the first node is text that keeps
    // the literal dollar.
    match Math::expand_inline(r"price \$5", &dummy_span()) {
      None => {}
      Some(nodes) => assert!(matches!(&nodes[0], Node::Text(t) if t.value.contains("$5"))),
    }
  }

  #[test]
  fn unmatched_dollar_left_alone() {
    let result = Math::expand_inline("a $ stray", &dummy_span());
    assert!(result.is_none());
  }

  #[test]
  fn block_math_unwraps() {
    let span = dummy_span();
    let text = Node::Text(Text { value: "$$ x = y $$".into(), span: span.clone() });
    let paragraph = Node::Paragraph(Paragraph { children: vec![text], span: span.clone() });
    let mut doc = Document { children: vec![paragraph], span };

    walk_root(&mut doc.children, &mut Apply);

    assert_eq!(doc.children.len(), 1);
    let Node::JsxSelfClosing(element) = &doc.children[0] else {
      panic!("expected MathMl element, got {:?}", doc.children[0]);
    };
    assert_eq!(element.name, "MathMl");
    let attr = element.attrs.iter().find(|a| a.name == "mathml").unwrap();
    assert!(matches!(
      &attr.value,
      JsxAttrValue::String(markup) if markup.contains("<math") && markup.contains("display=\"block\"")
    ));
  }
}