Skip to main content

dmc_transform/builtin/
math.rs

1//! Render LaTeX in `$...$` (inline) and `$$...$$` (block) to KaTeX HTML
2//! via the embedded KaTeX engine (`katex` crate, quick-js backend).
3//! Output is byte-equivalent to the JS chain `remark-math` +
4//! `rehype-katex`. Consumers must ship `katex.min.css` for glyphs.
5//!
6//! Block-level math (a paragraph whose entire content is `$$...$$`)
7//! replaces the wrapping `<p>` so the rendered HTML lands at block level.
8//!
9//! Escaped delimiters (`\$`) pass through as literal `$`. Parse failures
10//! emit `<span class="math-error">` wrapping the original LaTeX.
11
12use crate::pipeline::Transformer;
13use crate::visit::{NodeAction, Visitor, walk_root};
14use dmc_diagnostic::Code;
15use dmc_diagnostic::metadata::SourceMeta;
16use dmc_parser::ast::*;
17use duck_diagnostic::DiagnosticEngine;
18use std::collections::HashMap;
19use std::sync::{Mutex, OnceLock};
20
/// Cache key: the LaTeX source, the display-mode flag, and the engine
/// that produced the rendering.
type MathCacheKey = (String, bool, crate::MathEngine);
/// In-memory map from a [`MathCacheKey`] to the rendered markup string.
type MathCache = HashMap<MathCacheKey, String>;
23
/// Render `$...$` and `$$...$$` math spans to markup carried by a
/// `<MathMl/>` element. The markup is produced by whichever
/// [`crate::MathEngine`] is active (KaTeX or MathML).
///
/// Two entry points:
/// - [`Math::preprocess_source`] runs before the dmc lexer; rewrites raw
///   `$...$` / `$$...$$` to `<MathMl mathml="..."/>` JSX so the parser
///   never sees unescaped LaTeX. Required: `_` and `^` inside math would
///   otherwise be parsed as Markdown emphasis markers.
/// - [`Transformer`] impl runs as a pipeline pass on already-parsed AST.
///   Used by tests and any caller that builds a Document directly.
#[derive(Default, Debug)]
pub struct Math;
35
36impl Math {
  /// Rewrite `$...$` / `$$...$$` in raw MDX source to `<MathMl/>` JSX.
  /// Skips fenced code blocks, inline code spans, and existing JSX tags.
  ///
  /// The source is scanned left to right by byte offset. At each offset
  /// the checks run in a fixed order: fenced code, inline code, JSX tag,
  /// escaped `\$`, math delimiter, and finally a plain character copy.
  /// Anything that is not rewritten is copied to the output verbatim.
  ///
  /// Returns the rewritten source as a new `String`.
  pub fn preprocess_source(source: &str) -> String {
    let mut out = String::with_capacity(source.len());
    let bytes = source.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
      // Regions that must not be touched are copied through untouched.
      if let Some(end) = Self::skip_fenced_code(source, bytes, i) {
        out.push_str(&source[i..end]);
        i = end;
        continue;
      }
      if let Some(end) = Self::skip_inline_code(source, bytes, i) {
        out.push_str(&source[i..end]);
        i = end;
        continue;
      }
      if let Some(end) = Self::skip_jsx_tag(source, bytes, i) {
        out.push_str(&source[i..end]);
        i = end;
        continue;
      }
      // An escaped dollar is kept as-is (backslash included) so it never
      // opens a math span here.
      if bytes[i] == b'\\' && bytes.get(i + 1) == Some(&b'$') {
        out.push_str("\\$");
        i += 2;
        continue;
      }
      if bytes[i] == b'$' {
        // `$$` opens display math, a single `$` opens inline math.
        let display = bytes.get(i + 1) == Some(&b'$');
        let delim_len = if display { 2 } else { 1 };
        let body_start = i + delim_len;
        // Display math closes at the next `$$`; inline math uses the
        // stricter single-line search in `find_inline_close`.
        let close_off =
          if display { source[body_start..].find("$$") } else { Self::find_inline_close(&source[body_start..]) };
        if let Some(off) = close_off {
          let inner = &source[body_start..body_start + off];
          // Whitespace-only bodies are not treated as math.
          if !inner.trim().is_empty() {
            let rendered = Self::render(inner, display);
            out.push_str(&format!("<MathMl mathml=\"{}\"/>", Self::escape_jsx_attr(&rendered)));
            i = body_start + off + delim_len;
            continue;
          }
        }
        // No usable closer: emit one literal `$` and rescan from the next
        // byte (for `$$` the second `$` is reconsidered as an opener).
        out.push('$');
        i += 1;
        continue;
      }
      // Copy one whole UTF-8 character so `i` stays on a char boundary.
      let ch_len = utf8_char_len(bytes[i]);
      out.push_str(&source[i..i + ch_len]);
      i += ch_len;
    }
    out
  }
89
90  /// Render a LaTeX string. Engine is the active [`crate::MathEngine`]
91  /// (default KaTeX HTML; can be flipped to MathML via `pulldown-latex`).
92  /// Cached by `(latex, display, engine)` so repeated math in a doc hits
93  /// the renderer once. On parse failure returns a
94  /// `<span class="math-error">` wrapper around the original LaTeX.
95  pub fn render(latex: &str, display: bool) -> String {
96    let engine = Self::active_engine();
97    let cache = Self::cache();
98    let key = (latex.to_string(), display, engine);
99    if let Some(hit) = cache.lock().expect("math cache lock").get(&key) {
100      return hit.clone();
101    }
102    let html = match engine {
103      crate::MathEngine::Katex => Self::render_katex(latex, display),
104      crate::MathEngine::Mathml => Self::render_mathml(latex, display),
105    };
106    cache.lock().expect("math cache lock").insert(key, html.clone());
107    html
108  }
109
110  fn render_katex(latex: &str, display: bool) -> String {
111    let opts = if display { Self::display_opts() } else { Self::inline_opts() };
112    match katex::render_with_opts(latex, opts) {
113      Ok(html) => html,
114      Err(_) => Self::error_span(latex, display),
115    }
116  }
117
118  fn render_mathml(latex: &str, display: bool) -> String {
119    use pulldown_latex::config::DisplayMode;
120    use pulldown_latex::{Parser, RenderConfig, Storage, mathml::push_mathml};
121    let storage = Storage::new();
122    let parser = Parser::new(latex, &storage);
123    let cfg = RenderConfig {
124      display_mode: if display { DisplayMode::Block } else { DisplayMode::Inline },
125      ..Default::default()
126    };
127    let mut out = String::new();
128    match push_mathml(&mut out, parser, cfg) {
129      Ok(()) => out,
130      Err(_) => Self::error_span(latex, display),
131    }
132  }
133
134  fn error_span(latex: &str, display: bool) -> String {
135    format!(
136      "<span class=\"math-error\">{}{}{}</span>",
137      if display { "$$" } else { "$" },
138      latex,
139      if display { "$$" } else { "$" }
140    )
141  }
142
143  /// Process-wide active engine. Set once via [`Math::set_engine`]
144  /// (pipeline does this from `PipelineConfig::math_engine`); defaults
145  /// to KaTeX. Stored as a static so [`render`] does not need a
146  /// per-call engine argument (keeps the source preprocessor signature
147  /// engine-agnostic).
148  pub fn set_engine(engine: crate::MathEngine) {
149    Self::engine_slot().store(engine_to_u8(engine), std::sync::atomic::Ordering::Release);
150  }
151
152  fn active_engine() -> crate::MathEngine {
153    u8_to_engine(Self::engine_slot().load(std::sync::atomic::Ordering::Acquire))
154  }
155
156  fn engine_slot() -> &'static std::sync::atomic::AtomicU8 {
157    static S: OnceLock<std::sync::atomic::AtomicU8> = OnceLock::new();
158    S.get_or_init(|| std::sync::atomic::AtomicU8::new(engine_to_u8(crate::MathEngine::default())))
159  }
160
161  fn cache() -> &'static Mutex<MathCache> {
162    static C: OnceLock<Mutex<MathCache>> = OnceLock::new();
163    C.get_or_init(|| Mutex::new(HashMap::new()))
164  }
165
166  /// Load a previously persisted math cache from `path`. Silently ignores
167  /// missing/corrupt files. Idempotent: existing in-memory entries are
168  /// kept (cache wins are additive).
169  pub fn load_cache(path: &std::path::Path) {
170    let Ok(s) = std::fs::read_to_string(path) else { return };
171    let Ok(rows) = serde_json::from_str::<Vec<(String, bool, u8, String)>>(&s) else { return };
172    let mut cache = Self::cache().lock().expect("math cache lock");
173    for (latex, display, eng, html) in rows {
174      cache.entry((latex, display, u8_to_engine(eng))).or_insert(html);
175    }
176  }
177
178  /// Persist the in-memory math cache to `path`. Best effort; errors
179  /// are swallowed.
180  pub fn save_cache(path: &std::path::Path) {
181    let cache = Self::cache().lock().expect("math cache lock");
182    let rows: Vec<(String, bool, u8, String)> = cache
183      .iter()
184      .map(|((latex, display, eng), html)| (latex.clone(), *display, engine_to_u8(*eng), html.clone()))
185      .collect();
186    let Ok(json) = serde_json::to_string(&rows) else { return };
187    if let Some(parent) = path.parent() {
188      let _ = std::fs::create_dir_all(parent);
189    }
190    let _ = std::fs::write(path, json);
191  }
192
193  fn display_opts() -> &'static katex::Opts {
194    static O: OnceLock<katex::Opts> = OnceLock::new();
195    O.get_or_init(|| {
196      katex::Opts::builder()
197        .display_mode(true)
198        .output_type(katex::OutputType::HtmlAndMathml)
199        .build()
200        .expect("katex opts")
201    })
202  }
203
204  fn inline_opts() -> &'static katex::Opts {
205    static O: OnceLock<katex::Opts> = OnceLock::new();
206    O.get_or_init(|| {
207      katex::Opts::builder()
208        .display_mode(false)
209        .output_type(katex::OutputType::HtmlAndMathml)
210        .build()
211        .expect("katex opts")
212    })
213  }
214
215  /// Render LaTeX as a self-closing `<MathMl/>` JsxSelfClosing node.
216  pub fn render_node(latex: &str, display: bool, span: &duck_diagnostic::Span) -> Node {
217    let mathml = Self::render(latex, display);
218    Node::JsxSelfClosing(JsxSelfClosing {
219      name: "MathMl".into(),
220      attrs: vec![JsxAttr { name: "mathml".into(), value: JsxAttrValue::String(mathml), span: span.clone() }],
221      span: span.clone(),
222    })
223  }
224
225  // ---------------------------------------------------------------- helpers
226
227  fn skip_fenced_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
228    if bytes[i] != b'`' || bytes.get(i + 1) != Some(&b'`') || bytes.get(i + 2) != Some(&b'`') {
229      return None;
230    }
231    let end = source[i + 3..].find("```").map(|p| i + 3 + p + 3).unwrap_or(bytes.len());
232    Some(end)
233  }
234
235  fn skip_inline_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
236    if bytes[i] != b'`' {
237      return None;
238    }
239    let p = source[i + 1..].find('`')?;
240    Some(i + 1 + p + 1)
241  }
242
243  fn skip_jsx_tag(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
244    if bytes[i] != b'<' {
245      return None;
246    }
247    let p = source[i + 1..].find('>')?;
248    Some(i + 1 + p + 1)
249  }
250
251  fn find_inline_close(inline: &str) -> Option<usize> {
252    let mut search = 0usize;
253    while search < inline.len() {
254      let rel = inline[search..].find(['$', '\n'])?;
255      let abs = search + rel;
256      if inline.as_bytes()[abs] == b'\n' {
257        return None;
258      }
259      if abs > 0 && inline.as_bytes()[abs - 1] == b'\\' {
260        search = abs + 1;
261        continue;
262      }
263      return Some(abs);
264    }
265    None
266  }
267
268  /// Escape `"` and `&` so MathML survives JSX attribute parsing.
269  /// Reversed by the codegen `MathMl` raw-HTML paster.
270  fn escape_jsx_attr(s: &str) -> String {
271    s.replace('&', "&amp;").replace('"', "&quot;")
272  }
273}
274
275impl Transformer for Math {
276  fn name(&self) -> &str {
277    "math"
278  }
279  fn transform(&self, doc: &mut Document, _meta: &SourceMeta, _engine: &mut DiagnosticEngine<Code>) {
280    let mut v = Apply;
281    walk_root(&mut doc.children, &mut v);
282  }
283}
284
285struct Apply;
286
287impl Visitor for Apply {
288  fn visit_node(&mut self, node: &mut Node) -> NodeAction {
289    if let Node::Paragraph(p) = node
290      && let [Node::Text(t)] = p.children.as_slice()
291      && let Some(latex) = Math::unwrap_block(t.value.trim())
292    {
293      let span = t.span.clone();
294      return NodeAction::Replace(vec![Math::render_node(latex, true, &span)]);
295    }
296    let Node::Text(t) = node else { return NodeAction::Keep };
297    let Some(replacement) = Math::expand_inline(&t.value, &t.span) else { return NodeAction::Keep };
298    NodeAction::Replace(replacement)
299  }
300}
301
302impl Math {
303  fn unwrap_block(s: &str) -> Option<&str> {
304    let s = s.trim();
305    let inner = s.strip_prefix("$$")?.strip_suffix("$$")?;
306    Some(inner.trim())
307  }
308
  /// Split `text` into alternating text and math nodes.
  ///
  /// Scans for `$...$` and `$$...$$` spans, rendering each through
  /// [`Self::render_node`] and collecting the text in between into
  /// `Text` nodes. Every produced node carries a clone of `span`.
  ///
  /// Returns `None` when `text` has no `$` at all or when no complete
  /// math span was found, so the caller can keep the original node.
  fn expand_inline(text: &str, span: &duck_diagnostic::Span) -> Option<Vec<Node>> {
    // Fast path: nothing to do without a dollar sign.
    if !text.contains('$') {
      return None;
    }
    let mut out: Vec<Node> = Vec::new();
    // Pending plain text, flushed as a `Text` node before each math node.
    let mut buf = String::new();
    let bytes = text.as_bytes();
    let mut i = 0;
    let mut found_any = false;

    while i < bytes.len() {
      let c = bytes[i];
      // `\$` becomes a literal `$` in the text (the backslash is dropped).
      if c == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'$' {
        buf.push('$');
        i += 2;
        continue;
      }
      // Ordinary character: copy it whole so `i` stays on a char boundary.
      if c != b'$' {
        let ch_len = utf8_char_len(bytes[i]);
        buf.push_str(&text[i..i + ch_len]);
        i += ch_len;
        continue;
      }
      // `$$` selects display mode, a single `$` inline mode.
      let (delim, display) = if i + 1 < bytes.len() && bytes[i + 1] == b'$' { ("$$", true) } else { ("$", false) };
      let inner_start = i + delim.len();
      // No matching closer: keep one `$` as text and rescan from the next
      // byte.
      let Some(close_off) = Self::find_unescaped(&text[inner_start..], delim) else {
        buf.push('$');
        i += 1;
        continue;
      };
      let inner = &text[inner_start..inner_start + close_off];
      if !buf.is_empty() {
        out.push(Node::Text(Text { value: std::mem::take(&mut buf), span: span.clone() }));
      }
      out.push(Self::render_node(inner, display, span));
      // Resume just past the closing delimiter.
      i = inner_start + close_off + delim.len();
      found_any = true;
    }

    // Only escaped or stray dollars were seen: report "no change".
    if !found_any {
      return None;
    }
    // Flush any text that trails the last math span.
    if !buf.is_empty() {
      out.push(Node::Text(Text { value: buf, span: span.clone() }));
    }
    Some(out)
  }
356
357  fn find_unescaped(haystack: &str, delim: &str) -> Option<usize> {
358    let mut search_from = 0;
359    while search_from < haystack.len() {
360      let off = haystack[search_from..].find(delim)?;
361      let abs = search_from + off;
362      if abs > 0 && haystack.as_bytes()[abs - 1] == b'\\' {
363        search_from = abs + delim.len();
364        continue;
365      }
366      return Some(abs);
367    }
368    None
369  }
370}
371
372fn engine_to_u8(e: crate::MathEngine) -> u8 {
373  match e {
374    crate::MathEngine::Katex => 0,
375    crate::MathEngine::Mathml => 1,
376  }
377}
378
379fn u8_to_engine(b: u8) -> crate::MathEngine {
380  match b {
381    1 => crate::MathEngine::Mathml,
382    _ => crate::MathEngine::Katex,
383  }
384}
385
/// Length in bytes of the UTF-8 sequence introduced by lead byte `b`.
///
/// Intended for bytes that sit on a char boundary. No validation is
/// done: anything below `0xE0` that is not ASCII (continuation bytes
/// included) reports 2.
fn utf8_char_len(b: u8) -> usize {
  match b {
    0x00..=0x7F => 1,
    0x80..=0xDF => 2,
    0xE0..=0xEF => 3,
    0xF0..=0xFF => 4,
  }
}
397
// Unit tests for the AST-level math pass: inline expansion of text
// nodes and block-level replacement of `$$...$$` paragraphs.
#[cfg(test)]
mod tests {
  use super::*;
  use duck_diagnostic::Span;
  use std::sync::Arc;

  /// Dummy span shared by all nodes built in these tests.
  fn s() -> Span {
    Span { file: Arc::from("<t>"), line: 1, column: 1, length: 0 }
  }

  /// Text without any `$` is reported as "no change".
  #[test]
  fn passthrough_when_no_dollars() {
    assert!(Math::expand_inline("nothing here", &s()).is_none());
  }

  /// One inline span splits the text into text / math / text.
  #[test]
  fn inline_math_replaces_one_span() {
    let r = Math::expand_inline("a $x+1$ b", &s()).expect("matched");
    // Text("a ") + JsxSelfClosing(MathMl) + Text(" b")
    assert_eq!(r.len(), 3);
    assert!(matches!(&r[0], Node::Text(t) if t.value == "a "));
    assert!(matches!(&r[1], Node::JsxSelfClosing(e) if e.name == "MathMl"));
    assert!(matches!(&r[2], Node::Text(t) if t.value == " b"));
  }

  /// An escaped `\$` must not open a math span. Either outcome is
  /// accepted: no change at all, or a text node containing `$5`.
  #[test]
  fn escaped_dollar_is_literal() {
    let r = Math::expand_inline(r"price \$5", &s());
    assert!(r.is_none() || matches!(&r.unwrap()[0], Node::Text(t) if t.value.contains("$5")));
  }

  /// A lone `$` with no closer leaves the text untouched.
  #[test]
  fn unmatched_dollar_left_alone() {
    assert!(Math::expand_inline("a $ stray", &s()).is_none());
  }

  /// A paragraph consisting solely of `$$...$$` is replaced by a single
  /// `<MathMl/>` node whose markup is in display (block) mode.
  #[test]
  fn block_math_unwraps() {
    let s_ = s();
    let mut p = Document {
      children: vec![Node::Paragraph(Paragraph {
        children: vec![Node::Text(Text { value: "$$ x = y $$".into(), span: s_.clone() })],
        span: s_.clone(),
      })],
      span: s_,
    };
    let mut v = Apply;
    walk_root(&mut p.children, &mut v);
    // The paragraph wrapper itself is gone; only the math node remains.
    assert_eq!(p.children.len(), 1);
    if let Node::JsxSelfClosing(e) = &p.children[0] {
      assert_eq!(e.name, "MathMl");
      let mathml = e.attrs.iter().find(|a| a.name == "mathml").unwrap();
      assert!(
        matches!(&mathml.value, JsxAttrValue::String(s) if s.contains("<math") && s.contains("display=\"block\""))
      );
    } else {
      panic!("expected MathMl element, got {:?}", p.children[0]);
    }
  }
}