1use crate::pipeline::Transformer;
4use crate::visit::{NodeAction, Visitor, walk_root};
5use dmc_diagnostic::metadata::SourceMeta;
6use dmc_diagnostic::{Code, DiagResult};
7use dmc_parser::ast::*;
8use duck_diagnostic::{DiagnosticEngine, diag};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11
/// Cache key for one rendered formula: the LaTeX source text, the display-mode flag, and the
/// engine that produced the output.
type MathCacheKey = (String, bool, crate::MathEngine);
/// Map from a [`MathCacheKey`] to the markup string rendered for it.
type MathCache = HashMap<MathCacheKey, String>;
14
/// Field-less type whose associated functions preprocess, render and cache math; it is also
/// the type that implements `Transformer` in this file.
#[derive(Default, Debug)]
pub struct Math;
26
27impl Math {
28 pub fn preprocess_source(source: &str) -> String {
31 let mut out = String::with_capacity(source.len());
32 let bytes = source.as_bytes();
33 let mut i = 0;
34 if let Some(end) = Self::skip_frontmatter(source, bytes) {
35 out.push_str(&source[..end]);
36 i = end;
37 }
38 while i < bytes.len() {
39 if let Some(end) = Self::skip_fenced_code(source, bytes, i) {
40 out.push_str(&source[i..end]);
41 i = end;
42 continue;
43 }
44 if let Some(end) = Self::skip_inline_code(source, bytes, i) {
45 out.push_str(&source[i..end]);
46 i = end;
47 continue;
48 }
49 if let Some(end) = Self::skip_jsx_tag(source, bytes, i) {
50 out.push_str(&source[i..end]);
51 i = end;
52 continue;
53 }
54 if bytes[i] == b'\\' && bytes.get(i + 1) == Some(&b'$') {
55 out.push_str("\\$");
56 i += 2;
57 continue;
58 }
59 if bytes[i] == b'$' {
60 let display = bytes.get(i + 1) == Some(&b'$');
61 let delim_len = if display { 2 } else { 1 };
62 let body_start = i + delim_len;
63 let close_off =
64 if display { source[body_start..].find("$$") } else { Self::find_inline_close(&source[body_start..]) };
65 if let Some(off) = close_off {
66 let inner = &source[body_start..body_start + off];
67 if !inner.trim().is_empty() {
68 let rendered = Self::render(inner, display);
69 out.push_str(&format!("<MathMl mathml=\"{}\"/>", Self::escape_jsx_attr(&rendered)));
70 i = body_start + off + delim_len;
71 continue;
72 }
73 }
74 out.push('$');
75 i += 1;
76 continue;
77 }
78 let ch_len = utf8_char_len(bytes[i]);
79 out.push_str(&source[i..i + ch_len]);
80 i += ch_len;
81 }
82 out
83 }
84
85 pub fn render(latex: &str, display: bool) -> String {
91 let engine = Self::active_engine();
92 let cache = Self::cache();
93 let key = (latex.to_string(), display, engine);
94 if let Some(hit) = cache.lock().expect("math cache lock").get(&key) {
95 return hit.clone();
96 }
97 let html = match engine {
98 crate::MathEngine::Katex => Self::render_katex(latex, display),
99 crate::MathEngine::Mathml => Self::render_mathml(latex, display),
100 };
101 cache.lock().expect("math cache lock").insert(key, html.clone());
102 html
103 }
104
105 fn render_katex(latex: &str, display: bool) -> String {
106 let opts_result = if display { Self::display_opts() } else { Self::inline_opts() };
107 let opts = match opts_result {
108 Ok(o) => o,
109 Err(_) => return Self::error_span(latex, display),
116 };
117 match katex::render_with_opts(latex, opts) {
118 Ok(html) => html,
119 Err(_) => Self::error_span(latex, display),
120 }
121 }
122
123 fn render_mathml(latex: &str, display: bool) -> String {
124 use pulldown_latex::config::DisplayMode;
125 use pulldown_latex::{Parser, RenderConfig, Storage, mathml::push_mathml};
126 let storage = Storage::new();
127 let parser = Parser::new(latex, &storage);
128 let cfg = RenderConfig {
129 display_mode: if display { DisplayMode::Block } else { DisplayMode::Inline },
130 ..Default::default()
131 };
132 let mut out = String::new();
133 match push_mathml(&mut out, parser, cfg) {
134 Ok(()) => out,
135 Err(_) => Self::error_span(latex, display),
136 }
137 }
138
/// Builds the fallback markup shown when a formula cannot be rendered: the original source,
/// wrapped in its `$` / `$$` delimiters, inside `<span class="math-error">`.
///
/// The LaTeX text is HTML-escaped (`&`, `<`, `>`) before interpolation. Without that, a
/// failing formula such as `a < b` would be emitted as raw markup and could break or inject
/// elements into the surrounding HTML.
fn error_span(latex: &str, display: bool) -> String {
    let delim = if display { "$$" } else { "$" };
    let mut escaped = String::with_capacity(latex.len());
    for ch in latex.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            other => escaped.push(other),
        }
    }
    format!("<span class=\"math-error\">{}{}{}</span>", delim, escaped, delim)
}
147
148 pub fn set_engine(engine: crate::MathEngine) {
154 Self::engine_slot().store(engine_to_u8(engine), std::sync::atomic::Ordering::Release);
155 }
156
157 fn active_engine() -> crate::MathEngine {
158 u8_to_engine(Self::engine_slot().load(std::sync::atomic::Ordering::Acquire))
159 }
160
161 fn engine_slot() -> &'static std::sync::atomic::AtomicU8 {
162 static S: OnceLock<std::sync::atomic::AtomicU8> = OnceLock::new();
163 S.get_or_init(|| std::sync::atomic::AtomicU8::new(engine_to_u8(crate::MathEngine::default())))
164 }
165
166 fn cache() -> &'static Mutex<MathCache> {
167 static C: OnceLock<Mutex<MathCache>> = OnceLock::new();
168 C.get_or_init(|| Mutex::new(HashMap::new()))
169 }
170
171 #[allow(clippy::result_large_err)]
175 pub fn load_cache(path: &std::path::Path) -> DiagResult {
176 let s = match std::fs::read_to_string(path) {
177 Ok(s) => s,
178 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
179 Err(e) => return Err(diag!(Code::IoRead, format!("math cache read at {}: {}", path.display(), e))),
180 };
181
182 let rows = match serde_json::from_str::<Vec<(String, bool, u8, String)>>(&s) {
183 Ok(r) => r,
184 Err(_) => return Ok(()),
185 };
186
187 let mut cache = Self::cache().lock().map_err(|e| diag!(Code::LockPoisoned, format!("math cache lock: {}", e)))?;
188
189 for (latex, display, eng, html) in rows {
190 cache.entry((latex, display, u8_to_engine(eng))).or_insert(html);
191 }
192 Ok(())
193 }
194
195 #[allow(clippy::result_large_err)]
198 pub fn save_cache(path: &std::path::Path) -> DiagResult {
199 let cache = Self::cache().lock().expect("math cache lock");
200 let rows: Vec<(String, bool, u8, String)> = cache
201 .iter()
202 .map(|((latex, display, eng), html)| (latex.clone(), *display, engine_to_u8(*eng), html.clone()))
203 .collect();
204
205 let json =
206 serde_json::to_string(&rows).map_err(|e| diag!(Code::JsonSerialize, format!("math cache serialise: {}", e)))?;
207
208 if let Some(parent) = path.parent() {
209 std::fs::create_dir_all(parent)
210 .map_err(|e| diag!(Code::IoCreateDir, format!("math cache dir at {}: {}", parent.display(), e)))?;
211 }
212
213 std::fs::write(path, json)
214 .map_err(|e| diag!(Code::IoWrite, format!("math cache write at {}: {}", path.display(), e)))?;
215 Ok(())
216 }
217
218 #[allow(clippy::result_large_err)]
219 fn display_opts() -> DiagResult<&'static katex::Opts> {
220 static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
221 let cached = O.get_or_init(|| {
222 katex::Opts::builder()
223 .display_mode(true)
224 .output_type(katex::OutputType::HtmlAndMathml)
225 .build()
226 .map_err(|e| e.to_string())
227 });
228 cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
229 }
230
231 #[allow(clippy::result_large_err)]
237 fn inline_opts() -> DiagResult<&'static katex::Opts> {
238 static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
239 let cached = O.get_or_init(|| {
240 katex::Opts::builder()
241 .display_mode(false)
242 .output_type(katex::OutputType::HtmlAndMathml)
243 .build()
244 .map_err(|e| e.to_string())
245 });
246 cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
247 }
248
249 pub fn render_node(latex: &str, display: bool, span: &duck_diagnostic::Span) -> Node {
251 let mathml = Self::render(latex, display);
252 Node::JsxSelfClosing(JsxSelfClosing {
253 name: "MathMl".into(),
254 attrs: vec![JsxAttr { name: "mathml".into(), value: JsxAttrValue::String(mathml), span: span.clone() }],
255 span: span.clone(),
256 })
257 }
258
/// Detects a leading `---` or `+++` frontmatter block and returns the byte offset just past
/// its closing fence line (including that line's `\r`/`\n` if present).
///
/// Returns `None` when the text does not start with a fence line, or when no closing fence
/// that occupies a whole line is found. `bytes` is expected to be `source.as_bytes()`.
fn skip_frontmatter(source: &str, bytes: &[u8]) -> Option<usize> {
    // A fence only opens frontmatter when it is immediately followed by a line break.
    let opens_with = |marker: &str| {
        bytes.starts_with(marker.as_bytes())
            && matches!(&bytes[marker.len()..], [b'\n', ..] | [b'\r', b'\n', ..])
    };
    let fence = if opens_with("---") {
        "---"
    } else if opens_with("+++") {
        "+++"
    } else {
        return None;
    };

    // Skip the opening fence plus its LF or CRLF.
    let body_start = if bytes[3] == b'\r' { 5 } else { 4 };
    let rest = &source[body_start..];
    let rest_bytes = rest.as_bytes();

    let mut from = 0usize;
    while let Some(rel) = rest[from..].find(fence) {
        let start = from + rel;
        let stop = start + fence.len();
        let opens_line = start == 0 || rest_bytes[start - 1] == b'\n';
        let closes_line = matches!(rest_bytes.get(stop), None | Some(b'\n') | Some(b'\r'));
        if opens_line && closes_line {
            let mut end = body_start + stop;
            if bytes.get(end) == Some(&b'\r') {
                end += 1;
            }
            if bytes.get(end) == Some(&b'\n') {
                end += 1;
            }
            return Some(end);
        }
        from = stop;
    }
    None
}
297
/// If a triple-backtick fence starts at byte `i`, returns the offset just past the next
/// triple-backtick, or the end of the input when the fence is never closed. Returns `None`
/// when no fence starts at `i`.
fn skip_fenced_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    let opens_fence = matches!(
        (bytes[i], bytes.get(i + 1), bytes.get(i + 2)),
        (b'`', Some(b'`'), Some(b'`'))
    );
    if !opens_fence {
        return None;
    }
    let body_start = i + 3;
    let end = match source[body_start..].find("```") {
        Some(rel) => body_start + rel + 3,
        None => bytes.len(),
    };
    Some(end)
}
305
/// If a backtick sits at byte `i`, returns the offset just past the next backtick after it.
/// Returns `None` when byte `i` is not a backtick or no later backtick exists.
fn skip_inline_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    match bytes[i] {
        b'`' => {
            let after_open = i + 1;
            source[after_open..].find('`').map(|rel| after_open + rel + 1)
        }
        _ => None,
    }
}
313
/// If `<` sits at byte `i`, returns the offset just past the next `>` after it. Returns
/// `None` when byte `i` is not `<` or no later `>` exists.
fn skip_jsx_tag(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    if bytes[i] == b'<' {
        let after_open = i + 1;
        let close = source[after_open..].find('>')?;
        return Some(after_open + close + 1);
    }
    None
}
321
/// Finds the closing `$` of an inline formula within `inline` (the text right after the
/// opening `$`).
///
/// Returns the byte offset of the first `$` that is not directly preceded by a backslash.
/// Returns `None` if a newline is reached first or no such `$` exists.
fn find_inline_close(inline: &str) -> Option<usize> {
    let raw = inline.as_bytes();
    // `$` and `\n` are ASCII, so they never occur inside a multi-byte UTF-8 sequence and a
    // plain byte scan sees exactly the same positions as a character search.
    for (idx, &byte) in raw.iter().enumerate() {
        match byte {
            b'\n' => return None,
            b'$' if idx == 0 || raw[idx - 1] != b'\\' => return Some(idx),
            _ => {}
        }
    }
    None
}
338
/// Escapes `s` for use inside a double-quoted JSX attribute value.
///
/// `&` becomes `&amp;` and `"` becomes `&quot;`. The ampersand is replaced first so the
/// entities introduced for quotes are not escaped a second time.
fn escape_jsx_attr(s: &str) -> String {
    s.replace('&', "&amp;").replace('"', "&quot;")
}
344}
345
346impl Transformer for Math {
347 fn name(&self) -> &str {
348 "math"
349 }
350 fn transform(&self, doc: &mut Document, _meta: &SourceMeta, _engine: &mut DiagnosticEngine<Code>) {
351 let mut v = Apply;
352 walk_root(&mut doc.children, &mut v);
353 }
354}
355
/// Stateless visitor that replaces math found in paragraph and text nodes with rendered
/// `MathMl` nodes.
struct Apply;
357
358impl Visitor for Apply {
359 fn visit_node(&mut self, node: &mut Node) -> NodeAction {
360 if let Node::Paragraph(p) = node
361 && let [Node::Text(t)] = p.children.as_slice()
362 && let Some(latex) = Math::unwrap_block(t.value.trim())
363 {
364 let span = t.span.clone();
365 return NodeAction::Replace(vec![Math::render_node(latex, true, &span)]);
366 }
367 let Node::Text(t) = node else { return NodeAction::Keep };
368 let Some(replacement) = Math::expand_inline(&t.value, &t.span) else { return NodeAction::Keep };
369 NodeAction::Replace(replacement)
370 }
371}
372
373impl Math {
/// Returns the trimmed body of `s` when, after trimming, it starts with `$$` and the
/// remainder ends with `$$`; otherwise `None`.
fn unwrap_block(s: &str) -> Option<&str> {
    s.trim()
        .strip_prefix("$$")
        .and_then(|after_open| after_open.strip_suffix("$$"))
        .map(str::trim)
}
379
380 fn expand_inline(text: &str, span: &duck_diagnostic::Span) -> Option<Vec<Node>> {
381 if !text.contains('$') {
382 return None;
383 }
384 let mut out: Vec<Node> = Vec::new();
385 let mut buf = String::new();
386 let bytes = text.as_bytes();
387 let mut i = 0;
388 let mut found_any = false;
389
390 while i < bytes.len() {
391 let c = bytes[i];
392 if c == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'$' {
393 buf.push('$');
394 i += 2;
395 continue;
396 }
397 if c != b'$' {
398 let ch_len = utf8_char_len(bytes[i]);
399 buf.push_str(&text[i..i + ch_len]);
400 i += ch_len;
401 continue;
402 }
403 let (delim, display) = if i + 1 < bytes.len() && bytes[i + 1] == b'$' { ("$$", true) } else { ("$", false) };
404 let inner_start = i + delim.len();
405 let Some(close_off) = Self::find_unescaped(&text[inner_start..], delim) else {
406 buf.push('$');
407 i += 1;
408 continue;
409 };
410 let inner = &text[inner_start..inner_start + close_off];
411 if !buf.is_empty() {
412 out.push(Node::Text(Text { value: std::mem::take(&mut buf), span: span.clone() }));
413 }
414 out.push(Self::render_node(inner, display, span));
415 i = inner_start + close_off + delim.len();
416 found_any = true;
417 }
418
419 if !found_any {
420 return None;
421 }
422 if !buf.is_empty() {
423 out.push(Node::Text(Text { value: buf, span: span.clone() }));
424 }
425 Some(out)
426 }
427
/// Returns the byte offset of the first occurrence of `delim` in `haystack` that is not
/// directly preceded by a backslash, scanning non-overlapping occurrences left to right.
fn find_unescaped(haystack: &str, delim: &str) -> Option<usize> {
    // An empty haystack never yields a match, whatever the delimiter.
    if haystack.is_empty() {
        return None;
    }
    let raw = haystack.as_bytes();
    haystack
        .match_indices(delim)
        .map(|(at, _)| at)
        .find(|&at| at == 0 || raw[at - 1] != b'\\')
}
441}
442
443fn engine_to_u8(e: crate::MathEngine) -> u8 {
444 match e {
445 crate::MathEngine::Katex => 0,
446 crate::MathEngine::Mathml => 1,
447 }
448}
449
450fn u8_to_engine(b: u8) -> crate::MathEngine {
451 match b {
452 1 => crate::MathEngine::Mathml,
453 _ => crate::MathEngine::Katex,
454 }
455}
456
/// Returns the number of bytes to advance for a character whose first byte is `b`:
/// 1 below `0x80`, 2 below `0xE0`, 3 below `0xF0`, and 4 otherwise.
fn utf8_char_len(b: u8) -> usize {
    match b {
        0x00..=0x7F => 1,
        0x80..=0xDF => 2,
        0xE0..=0xEF => 3,
        _ => 4,
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;
    use duck_diagnostic::Span;
    use std::sync::Arc;

    /// Placeholder span attached to every node built by these tests.
    fn dummy_span() -> Span {
        Span { file: Arc::from("<t>"), line: 1, column: 1, length: 0 }
    }

    /// Text without any `$` produces no replacement.
    #[test]
    fn passthrough_when_no_dollars() {
        let expanded = Math::expand_inline("nothing here", &dummy_span());
        assert!(expanded.is_none());
    }

    /// One inline formula splits the text into text / math / text.
    #[test]
    fn inline_math_replaces_one_span() {
        let nodes = Math::expand_inline("a $x+1$ b", &dummy_span()).expect("matched");
        assert_eq!(nodes.len(), 3);
        assert!(matches!(&nodes[0], Node::Text(lead) if lead.value == "a "));
        assert!(matches!(&nodes[1], Node::JsxSelfClosing(math) if math.name == "MathMl"));
        assert!(matches!(&nodes[2], Node::Text(tail) if tail.value == " b"));
    }

    /// An escaped dollar either yields no replacement or stays in the first text node.
    #[test]
    fn escaped_dollar_is_literal() {
        if let Some(nodes) = Math::expand_inline(r"price \$5", &dummy_span()) {
            assert!(matches!(&nodes[0], Node::Text(t) if t.value.contains("$5")));
        }
    }

    /// A single `$` with no partner is not treated as math.
    #[test]
    fn unmatched_dollar_left_alone() {
        let expanded = Math::expand_inline("a $ stray", &dummy_span());
        assert!(expanded.is_none());
    }

    /// A paragraph consisting solely of `$$…$$` becomes one block-mode `MathMl` element.
    #[test]
    fn block_math_unwraps() {
        let span = dummy_span();
        let text = Node::Text(Text { value: "$$ x = y $$".into(), span: span.clone() });
        let paragraph = Node::Paragraph(Paragraph { children: vec![text], span: span.clone() });
        let mut doc = Document { children: vec![paragraph], span };

        walk_root(&mut doc.children, &mut Apply);

        assert_eq!(doc.children.len(), 1);
        let Node::JsxSelfClosing(element) = &doc.children[0] else {
            panic!("expected MathMl element, got {:?}", doc.children[0]);
        };
        assert_eq!(element.name, "MathMl");
        let mathml = element.attrs.iter().find(|a| a.name == "mathml").unwrap();
        assert!(
            matches!(&mathml.value, JsxAttrValue::String(s) if s.contains("<math") && s.contains("display=\"block\""))
        );
    }
}