1use crate::pipeline::Transformer;
4use crate::visit::{NodeAction, Visitor, walk_root};
5use dmc_diagnostic::metadata::SourceMeta;
6use dmc_diagnostic::{Code, DiagResult};
7use dmc_parser::ast::*;
8use duck_diagnostic::{DiagnosticEngine, diag};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11
/// Cache key: the LaTeX source, whether it is display (block) math, and the engine used.
type MathCacheKey = (String, bool, crate::MathEngine);
/// Maps a [`MathCacheKey`] to the markup string produced for it.
type MathCache = HashMap<MathCacheKey, String>;

/// Math support: source preprocessing, rendering with a process-wide cache, and the
/// `"math"` [`Transformer`] that rewrites math found in parsed text nodes.
#[derive(Default, Debug)]
pub struct Math;
26
27impl Math {
28 pub fn preprocess_source(source: &str) -> String {
31 let mut out = String::with_capacity(source.len());
32 let bytes = source.as_bytes();
33 let mut i = 0;
34 while i < bytes.len() {
35 if let Some(end) = Self::skip_fenced_code(source, bytes, i) {
36 out.push_str(&source[i..end]);
37 i = end;
38 continue;
39 }
40 if let Some(end) = Self::skip_inline_code(source, bytes, i) {
41 out.push_str(&source[i..end]);
42 i = end;
43 continue;
44 }
45 if let Some(end) = Self::skip_jsx_tag(source, bytes, i) {
46 out.push_str(&source[i..end]);
47 i = end;
48 continue;
49 }
50 if bytes[i] == b'\\' && bytes.get(i + 1) == Some(&b'$') {
51 out.push_str("\\$");
52 i += 2;
53 continue;
54 }
55 if bytes[i] == b'$' {
56 let display = bytes.get(i + 1) == Some(&b'$');
57 let delim_len = if display { 2 } else { 1 };
58 let body_start = i + delim_len;
59 let close_off =
60 if display { source[body_start..].find("$$") } else { Self::find_inline_close(&source[body_start..]) };
61 if let Some(off) = close_off {
62 let inner = &source[body_start..body_start + off];
63 if !inner.trim().is_empty() {
64 let rendered = Self::render(inner, display);
65 out.push_str(&format!("<MathMl mathml=\"{}\"/>", Self::escape_jsx_attr(&rendered)));
66 i = body_start + off + delim_len;
67 continue;
68 }
69 }
70 out.push('$');
71 i += 1;
72 continue;
73 }
74 let ch_len = utf8_char_len(bytes[i]);
75 out.push_str(&source[i..i + ch_len]);
76 i += ch_len;
77 }
78 out
79 }
80
81 pub fn render(latex: &str, display: bool) -> String {
87 let engine = Self::active_engine();
88 let cache = Self::cache();
89 let key = (latex.to_string(), display, engine);
90 if let Some(hit) = cache.lock().expect("math cache lock").get(&key) {
91 return hit.clone();
92 }
93 let html = match engine {
94 crate::MathEngine::Katex => Self::render_katex(latex, display),
95 crate::MathEngine::Mathml => Self::render_mathml(latex, display),
96 };
97 cache.lock().expect("math cache lock").insert(key, html.clone());
98 html
99 }
100
101 fn render_katex(latex: &str, display: bool) -> String {
102 let opts_result = if display { Self::display_opts() } else { Self::inline_opts() };
103 let opts = match opts_result {
104 Ok(o) => o,
105 Err(_) => return Self::error_span(latex, display),
112 };
113 match katex::render_with_opts(latex, opts) {
114 Ok(html) => html,
115 Err(_) => Self::error_span(latex, display),
116 }
117 }
118
119 fn render_mathml(latex: &str, display: bool) -> String {
120 use pulldown_latex::config::DisplayMode;
121 use pulldown_latex::{Parser, RenderConfig, Storage, mathml::push_mathml};
122 let storage = Storage::new();
123 let parser = Parser::new(latex, &storage);
124 let cfg = RenderConfig {
125 display_mode: if display { DisplayMode::Block } else { DisplayMode::Inline },
126 ..Default::default()
127 };
128 let mut out = String::new();
129 match push_mathml(&mut out, parser, cfg) {
130 Ok(()) => out,
131 Err(_) => Self::error_span(latex, display),
132 }
133 }
134
/// Builds the fallback markup for math that failed to render: the original source, wrapped
/// in its `$` / `$$` delimiters, inside `<span class="math-error">`.
///
/// The LaTeX is HTML-escaped (`&`, `<`, `>`) because the result is used as markup; without
/// this a failing formula such as `a < b` would corrupt or inject into the output.
fn error_span(latex: &str, display: bool) -> String {
    let delim = if display { "$$" } else { "$" };
    let mut out = String::with_capacity(latex.len() + 2 * delim.len() + 32);
    out.push_str("<span class=\"math-error\">");
    out.push_str(delim);
    for ch in latex.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            other => out.push(other),
        }
    }
    out.push_str(delim);
    out.push_str("</span>");
    out
}
143
144 pub fn set_engine(engine: crate::MathEngine) {
150 Self::engine_slot().store(engine_to_u8(engine), std::sync::atomic::Ordering::Release);
151 }
152
153 fn active_engine() -> crate::MathEngine {
154 u8_to_engine(Self::engine_slot().load(std::sync::atomic::Ordering::Acquire))
155 }
156
157 fn engine_slot() -> &'static std::sync::atomic::AtomicU8 {
158 static S: OnceLock<std::sync::atomic::AtomicU8> = OnceLock::new();
159 S.get_or_init(|| std::sync::atomic::AtomicU8::new(engine_to_u8(crate::MathEngine::default())))
160 }
161
162 fn cache() -> &'static Mutex<MathCache> {
163 static C: OnceLock<Mutex<MathCache>> = OnceLock::new();
164 C.get_or_init(|| Mutex::new(HashMap::new()))
165 }
166
167 #[allow(clippy::result_large_err)]
171 pub fn load_cache(path: &std::path::Path) -> DiagResult {
172 let s = match std::fs::read_to_string(path) {
173 Ok(s) => s,
174 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
175 Err(e) => return Err(diag!(Code::IoRead, format!("math cache read at {}: {}", path.display(), e))),
176 };
177
178 let rows = match serde_json::from_str::<Vec<(String, bool, u8, String)>>(&s) {
179 Ok(r) => r,
180 Err(_) => return Ok(()),
181 };
182
183 let mut cache = Self::cache().lock().map_err(|e| diag!(Code::LockPoisoned, format!("math cache lock: {}", e)))?;
184
185 for (latex, display, eng, html) in rows {
186 cache.entry((latex, display, u8_to_engine(eng))).or_insert(html);
187 }
188 Ok(())
189 }
190
191 #[allow(clippy::result_large_err)]
194 pub fn save_cache(path: &std::path::Path) -> DiagResult {
195 let cache = Self::cache().lock().expect("math cache lock");
196 let rows: Vec<(String, bool, u8, String)> = cache
197 .iter()
198 .map(|((latex, display, eng), html)| (latex.clone(), *display, engine_to_u8(*eng), html.clone()))
199 .collect();
200
201 let json =
202 serde_json::to_string(&rows).map_err(|e| diag!(Code::JsonSerialize, format!("math cache serialise: {}", e)))?;
203
204 if let Some(parent) = path.parent() {
205 std::fs::create_dir_all(parent)
206 .map_err(|e| diag!(Code::IoCreateDir, format!("math cache dir at {}: {}", parent.display(), e)))?;
207 }
208
209 std::fs::write(path, json)
210 .map_err(|e| diag!(Code::IoWrite, format!("math cache write at {}: {}", path.display(), e)))?;
211 Ok(())
212 }
213
214 #[allow(clippy::result_large_err)]
215 fn display_opts() -> DiagResult<&'static katex::Opts> {
216 static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
217 let cached = O.get_or_init(|| {
218 katex::Opts::builder()
219 .display_mode(true)
220 .output_type(katex::OutputType::HtmlAndMathml)
221 .build()
222 .map_err(|e| e.to_string())
223 });
224 cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
225 }
226
227 #[allow(clippy::result_large_err)]
233 fn inline_opts() -> DiagResult<&'static katex::Opts> {
234 static O: OnceLock<Result<katex::Opts, String>> = OnceLock::new();
235 let cached = O.get_or_init(|| {
236 katex::Opts::builder()
237 .display_mode(false)
238 .output_type(katex::OutputType::HtmlAndMathml)
239 .build()
240 .map_err(|e| e.to_string())
241 });
242 cached.as_ref().map_err(|e| diag!(Code::KatexOpts, format!("katex opts: {}", e)))
243 }
244
245 pub fn render_node(latex: &str, display: bool, span: &duck_diagnostic::Span) -> Node {
247 let mathml = Self::render(latex, display);
248 Node::JsxSelfClosing(JsxSelfClosing {
249 name: "MathMl".into(),
250 attrs: vec![JsxAttr { name: "mathml".into(), value: JsxAttrValue::String(mathml), span: span.clone() }],
251 span: span.clone(),
252 })
253 }
254
/// If a ``` fence opens at byte `i`, returns the offset just past its closing fence, or the
/// end of the input when the fence is never closed. Returns `None` when no fence opens at `i`.
fn skip_fenced_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    let opens_fence = bytes[i] == b'`' && bytes[i + 1..].starts_with(b"``");
    if !opens_fence {
        return None;
    }
    let body_start = i + 3;
    let end = match source[body_start..].find("```") {
        Some(rel) => body_start + rel + 3,
        None => bytes.len(),
    };
    Some(end)
}
264
/// If a backtick opens at byte `i` and another backtick follows, returns the offset just
/// past the closing one. Returns `None` when `i` is not a backtick or the span never closes.
fn skip_inline_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    if bytes[i] == b'`' {
        let after_open = i + 1;
        source[after_open..].find('`').map(|rel| after_open + rel + 1)
    } else {
        None
    }
}
272
/// If `<` sits at byte `i` and a `>` follows, returns the offset just past that `>`.
/// Returns `None` when `i` is not `<` or no `>` follows.
fn skip_jsx_tag(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    match bytes[i] {
        b'<' => {
            let rest = &source[i + 1..];
            let close = rest.find('>')?;
            Some(i + close + 2)
        }
        _ => None,
    }
}
280
/// Finds the byte offset of the `$` that closes an inline formula.
///
/// A `$` directly preceded by a backslash is skipped. Inline math may not span lines, so
/// reaching a newline before an unescaped `$` gives `None`, as does running out of input.
fn find_inline_close(inline: &str) -> Option<usize> {
    let bytes = inline.as_bytes();
    // `$` and `\n` are ASCII, so a byte scan cannot land inside a multi-byte character.
    for (idx, &b) in bytes.iter().enumerate() {
        match b {
            b'\n' => return None,
            b'$' if idx == 0 || bytes[idx - 1] != b'\\' => return Some(idx),
            _ => {}
        }
    }
    None
}
297
/// Escapes `s` for use inside a double-quoted JSX attribute value: `&` becomes `&amp;` and
/// `"` becomes `&quot;`.
///
/// The ampersand must be escaped as well as the quote, otherwise existing entities in the
/// rendered markup would be decoded when the attribute is read back.
fn escape_jsx_attr(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '"' => out.push_str("&quot;"),
            other => out.push(other),
        }
    }
    out
}
303}
304
305impl Transformer for Math {
306 fn name(&self) -> &str {
307 "math"
308 }
309 fn transform(&self, doc: &mut Document, _meta: &SourceMeta, _engine: &mut DiagnosticEngine<Code>) {
310 let mut v = Apply;
311 walk_root(&mut doc.children, &mut v);
312 }
313}
314
315struct Apply;
316
317impl Visitor for Apply {
318 fn visit_node(&mut self, node: &mut Node) -> NodeAction {
319 if let Node::Paragraph(p) = node
320 && let [Node::Text(t)] = p.children.as_slice()
321 && let Some(latex) = Math::unwrap_block(t.value.trim())
322 {
323 let span = t.span.clone();
324 return NodeAction::Replace(vec![Math::render_node(latex, true, &span)]);
325 }
326 let Node::Text(t) = node else { return NodeAction::Keep };
327 let Some(replacement) = Math::expand_inline(&t.value, &t.span) else { return NodeAction::Keep };
328 NodeAction::Replace(replacement)
329 }
330}
331
332impl Math {
/// Returns the trimmed body of `s` when `s` is exactly one `$$ … $$` block.
///
/// Returns `None` when `s` is not wrapped in `$$`, when the body is empty or blank
/// (matching `preprocess_source`, which leaves such delimiters as literal text), or when
/// the body itself contains `$$`. The last case means `s` holds several formulas, such as
/// `$$a$$ and $$b$$`, which must not be rendered as the single formula `a$$ and $$b`.
fn unwrap_block(s: &str) -> Option<&str> {
    let inner = s.trim().strip_prefix("$$")?.strip_suffix("$$")?.trim();
    if inner.is_empty() || inner.contains("$$") {
        return None;
    }
    Some(inner)
}
338
339 fn expand_inline(text: &str, span: &duck_diagnostic::Span) -> Option<Vec<Node>> {
340 if !text.contains('$') {
341 return None;
342 }
343 let mut out: Vec<Node> = Vec::new();
344 let mut buf = String::new();
345 let bytes = text.as_bytes();
346 let mut i = 0;
347 let mut found_any = false;
348
349 while i < bytes.len() {
350 let c = bytes[i];
351 if c == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'$' {
352 buf.push('$');
353 i += 2;
354 continue;
355 }
356 if c != b'$' {
357 let ch_len = utf8_char_len(bytes[i]);
358 buf.push_str(&text[i..i + ch_len]);
359 i += ch_len;
360 continue;
361 }
362 let (delim, display) = if i + 1 < bytes.len() && bytes[i + 1] == b'$' { ("$$", true) } else { ("$", false) };
363 let inner_start = i + delim.len();
364 let Some(close_off) = Self::find_unescaped(&text[inner_start..], delim) else {
365 buf.push('$');
366 i += 1;
367 continue;
368 };
369 let inner = &text[inner_start..inner_start + close_off];
370 if !buf.is_empty() {
371 out.push(Node::Text(Text { value: std::mem::take(&mut buf), span: span.clone() }));
372 }
373 out.push(Self::render_node(inner, display, span));
374 i = inner_start + close_off + delim.len();
375 found_any = true;
376 }
377
378 if !found_any {
379 return None;
380 }
381 if !buf.is_empty() {
382 out.push(Node::Text(Text { value: buf, span: span.clone() }));
383 }
384 Some(out)
385 }
386
/// Returns the byte offset of the first occurrence of `delim` in `haystack` that is not
/// directly preceded by a backslash, or `None` if there is none.
///
/// After an escaped hit the search resumes one character later, not `delim.len()` bytes
/// later. Otherwise a real delimiter overlapping the escaped one is skipped: in `\$$$` the
/// first `$` is escaped and the closing `$$` starts at the very next byte.
fn find_unescaped(haystack: &str, delim: &str) -> Option<usize> {
    // Width of the delimiter's first character; keeps `search_from` on a char boundary.
    let step = delim.chars().next().map_or(1, char::len_utf8);
    let bytes = haystack.as_bytes();
    let mut search_from = 0;
    while search_from < haystack.len() {
        let abs = search_from + haystack[search_from..].find(delim)?;
        if abs > 0 && bytes[abs - 1] == b'\\' {
            search_from = abs + step;
            continue;
        }
        return Some(abs);
    }
    None
}
400}
401
402fn engine_to_u8(e: crate::MathEngine) -> u8 {
403 match e {
404 crate::MathEngine::Katex => 0,
405 crate::MathEngine::Mathml => 1,
406 }
407}
408
409fn u8_to_engine(b: u8) -> crate::MathEngine {
410 match b {
411 1 => crate::MathEngine::Mathml,
412 _ => crate::MathEngine::Katex,
413 }
414}
415
/// Byte length of the UTF-8 character whose first byte is `b`.
///
/// Callers pass a lead byte; a byte in `0x80..=0xDF` (which includes continuation bytes)
/// is reported as 2.
fn utf8_char_len(b: u8) -> usize {
    match b {
        0x00..=0x7F => 1,
        0x80..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xFF => 4,
    }
}
427
// Unit tests for inline expansion and for block-math replacement through the visitor.
#[cfg(test)]
mod tests {
  use super::*;
  use duck_diagnostic::Span;
  use std::sync::Arc;

  // Dummy span shared by every node built in these tests.
  fn s() -> Span {
    Span { file: Arc::from("<t>"), line: 1, column: 1, length: 0 }
  }

  // Text without any `$` is reported as "nothing to do".
  #[test]
  fn passthrough_when_no_dollars() {
    assert!(Math::expand_inline("nothing here", &s()).is_none());
  }

  // One inline formula splits the text into: leading text, math node, trailing text.
  #[test]
  fn inline_math_replaces_one_span() {
    let r = Math::expand_inline("a $x+1$ b", &s()).expect("matched");
    assert_eq!(r.len(), 3);
    assert!(matches!(&r[0], Node::Text(t) if t.value == "a "));
    assert!(matches!(&r[1], Node::JsxSelfClosing(e) if e.name == "MathMl"));
    assert!(matches!(&r[2], Node::Text(t) if t.value == " b"));
  }

  // An escaped `\$` never opens math; either outcome (untouched, or literal `$5` text) is accepted.
  #[test]
  fn escaped_dollar_is_literal() {
    let r = Math::expand_inline(r"price \$5", &s());
    assert!(r.is_none() || matches!(&r.unwrap()[0], Node::Text(t) if t.value.contains("$5")));
  }

  // A lone `$` with no closing delimiter renders nothing.
  #[test]
  fn unmatched_dollar_left_alone() {
    assert!(Math::expand_inline("a $ stray", &s()).is_none());
  }

  // A paragraph holding only `$$ … $$` is replaced by a single block-mode MathMl element.
  #[test]
  fn block_math_unwraps() {
    let s_ = s();
    let mut p = Document {
      children: vec![Node::Paragraph(Paragraph {
        children: vec![Node::Text(Text { value: "$$ x = y $$".into(), span: s_.clone() })],
        span: s_.clone(),
      })],
      span: s_,
    };
    let mut v = Apply;
    walk_root(&mut p.children, &mut v);
    assert_eq!(p.children.len(), 1);
    if let Node::JsxSelfClosing(e) = &p.children[0] {
      assert_eq!(e.name, "MathMl");
      let mathml = e.attrs.iter().find(|a| a.name == "mathml").unwrap();
      // The rendered attribute must contain a `<math>` element marked display="block".
      assert!(
        matches!(&mathml.value, JsxAttrValue::String(s) if s.contains("<math") && s.contains("display=\"block\""))
      );
    } else {
      panic!("expected MathMl element, got {:?}", p.children[0]);
    }
  }
}