1use crate::pipeline::Transformer;
13use crate::visit::{NodeAction, Visitor, walk_root};
14use dmc_diagnostic::Code;
15use dmc_diagnostic::metadata::SourceMeta;
16use dmc_parser::ast::*;
17use duck_diagnostic::DiagnosticEngine;
18use std::collections::HashMap;
19use std::sync::{Mutex, OnceLock};
20
/// Cache key: the LaTeX source, whether it is display (block) math, and the
/// engine that produced the markup.
type MathCacheKey = (String, bool, crate::MathEngine);
/// Rendered markup, keyed by [`MathCacheKey`].
type MathCache = HashMap<MathCacheKey, String>;
23
/// Stateless handle for the `"math"` pipeline transformer; all rendering
/// helpers are associated functions on this type.
#[derive(Default, Debug)]
pub struct Math;
35
36impl Math {
37 pub fn preprocess_source(source: &str) -> String {
40 let mut out = String::with_capacity(source.len());
41 let bytes = source.as_bytes();
42 let mut i = 0;
43 while i < bytes.len() {
44 if let Some(end) = Self::skip_fenced_code(source, bytes, i) {
45 out.push_str(&source[i..end]);
46 i = end;
47 continue;
48 }
49 if let Some(end) = Self::skip_inline_code(source, bytes, i) {
50 out.push_str(&source[i..end]);
51 i = end;
52 continue;
53 }
54 if let Some(end) = Self::skip_jsx_tag(source, bytes, i) {
55 out.push_str(&source[i..end]);
56 i = end;
57 continue;
58 }
59 if bytes[i] == b'\\' && bytes.get(i + 1) == Some(&b'$') {
60 out.push_str("\\$");
61 i += 2;
62 continue;
63 }
64 if bytes[i] == b'$' {
65 let display = bytes.get(i + 1) == Some(&b'$');
66 let delim_len = if display { 2 } else { 1 };
67 let body_start = i + delim_len;
68 let close_off =
69 if display { source[body_start..].find("$$") } else { Self::find_inline_close(&source[body_start..]) };
70 if let Some(off) = close_off {
71 let inner = &source[body_start..body_start + off];
72 if !inner.trim().is_empty() {
73 let rendered = Self::render(inner, display);
74 out.push_str(&format!("<MathMl mathml=\"{}\"/>", Self::escape_jsx_attr(&rendered)));
75 i = body_start + off + delim_len;
76 continue;
77 }
78 }
79 out.push('$');
80 i += 1;
81 continue;
82 }
83 let ch_len = utf8_char_len(bytes[i]);
84 out.push_str(&source[i..i + ch_len]);
85 i += ch_len;
86 }
87 out
88 }
89
90 pub fn render(latex: &str, display: bool) -> String {
96 let engine = Self::active_engine();
97 let cache = Self::cache();
98 let key = (latex.to_string(), display, engine);
99 if let Some(hit) = cache.lock().expect("math cache lock").get(&key) {
100 return hit.clone();
101 }
102 let html = match engine {
103 crate::MathEngine::Katex => Self::render_katex(latex, display),
104 crate::MathEngine::Mathml => Self::render_mathml(latex, display),
105 };
106 cache.lock().expect("math cache lock").insert(key, html.clone());
107 html
108 }
109
110 fn render_katex(latex: &str, display: bool) -> String {
111 let opts = if display { Self::display_opts() } else { Self::inline_opts() };
112 match katex::render_with_opts(latex, opts) {
113 Ok(html) => html,
114 Err(_) => Self::error_span(latex, display),
115 }
116 }
117
118 fn render_mathml(latex: &str, display: bool) -> String {
119 use pulldown_latex::config::DisplayMode;
120 use pulldown_latex::{Parser, RenderConfig, Storage, mathml::push_mathml};
121 let storage = Storage::new();
122 let parser = Parser::new(latex, &storage);
123 let cfg = RenderConfig {
124 display_mode: if display { DisplayMode::Block } else { DisplayMode::Inline },
125 ..Default::default()
126 };
127 let mut out = String::new();
128 match push_mathml(&mut out, parser, cfg) {
129 Ok(()) => out,
130 Err(_) => Self::error_span(latex, display),
131 }
132 }
133
/// Builds the fallback markup shown when `latex` could not be rendered: the
/// original source, wrapped in its `$` / `$$` delimiters, inside a
/// `<span class="math-error">`.
///
/// The LaTeX text is HTML-escaped (`&`, `<`, `>`) because it is embedded in
/// markup; unescaped input such as `a < b` would otherwise produce broken or
/// injectable HTML.
fn error_span(latex: &str, display: bool) -> String {
    let delim = if display { "$$" } else { "$" };
    let mut escaped = String::with_capacity(latex.len());
    for ch in latex.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            other => escaped.push(other),
        }
    }
    format!("<span class=\"math-error\">{delim}{escaped}{delim}</span>")
}
142
143 pub fn set_engine(engine: crate::MathEngine) {
149 Self::engine_slot().store(engine_to_u8(engine), std::sync::atomic::Ordering::Release);
150 }
151
152 fn active_engine() -> crate::MathEngine {
153 u8_to_engine(Self::engine_slot().load(std::sync::atomic::Ordering::Acquire))
154 }
155
156 fn engine_slot() -> &'static std::sync::atomic::AtomicU8 {
157 static S: OnceLock<std::sync::atomic::AtomicU8> = OnceLock::new();
158 S.get_or_init(|| std::sync::atomic::AtomicU8::new(engine_to_u8(crate::MathEngine::default())))
159 }
160
161 fn cache() -> &'static Mutex<MathCache> {
162 static C: OnceLock<Mutex<MathCache>> = OnceLock::new();
163 C.get_or_init(|| Mutex::new(HashMap::new()))
164 }
165
166 pub fn load_cache(path: &std::path::Path) {
170 let Ok(s) = std::fs::read_to_string(path) else { return };
171 let Ok(rows) = serde_json::from_str::<Vec<(String, bool, u8, String)>>(&s) else { return };
172 let mut cache = Self::cache().lock().expect("math cache lock");
173 for (latex, display, eng, html) in rows {
174 cache.entry((latex, display, u8_to_engine(eng))).or_insert(html);
175 }
176 }
177
178 pub fn save_cache(path: &std::path::Path) {
181 let cache = Self::cache().lock().expect("math cache lock");
182 let rows: Vec<(String, bool, u8, String)> = cache
183 .iter()
184 .map(|((latex, display, eng), html)| (latex.clone(), *display, engine_to_u8(*eng), html.clone()))
185 .collect();
186 let Ok(json) = serde_json::to_string(&rows) else { return };
187 if let Some(parent) = path.parent() {
188 let _ = std::fs::create_dir_all(parent);
189 }
190 let _ = std::fs::write(path, json);
191 }
192
193 fn display_opts() -> &'static katex::Opts {
194 static O: OnceLock<katex::Opts> = OnceLock::new();
195 O.get_or_init(|| {
196 katex::Opts::builder()
197 .display_mode(true)
198 .output_type(katex::OutputType::HtmlAndMathml)
199 .build()
200 .expect("katex opts")
201 })
202 }
203
204 fn inline_opts() -> &'static katex::Opts {
205 static O: OnceLock<katex::Opts> = OnceLock::new();
206 O.get_or_init(|| {
207 katex::Opts::builder()
208 .display_mode(false)
209 .output_type(katex::OutputType::HtmlAndMathml)
210 .build()
211 .expect("katex opts")
212 })
213 }
214
215 pub fn render_node(latex: &str, display: bool, span: &duck_diagnostic::Span) -> Node {
217 let mathml = Self::render(latex, display);
218 Node::JsxSelfClosing(JsxSelfClosing {
219 name: "MathMl".into(),
220 attrs: vec![JsxAttr { name: "mathml".into(), value: JsxAttrValue::String(mathml), span: span.clone() }],
221 span: span.clone(),
222 })
223 }
224
/// If a ``` fence opens at byte `i`, returns the offset just past its
/// closing fence, or the end of the input when the fence is never closed.
/// Returns `None` when no fence starts at `i`.
fn skip_fenced_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    let opens_fence = bytes[i] == b'`' && bytes[i + 1..].starts_with(b"``");
    if !opens_fence {
        return None;
    }
    let body_start = i + 3;
    let end = match source[body_start..].find("```") {
        Some(rel) => body_start + rel + 3,
        None => bytes.len(),
    };
    Some(end)
}
234
/// If a backtick sits at byte `i`, returns the offset just past the next
/// backtick; `None` when `i` is not a backtick or no closing one exists.
fn skip_inline_code(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    if bytes[i] == b'`' {
        source[i + 1..].find('`').map(|rel| i + rel + 2)
    } else {
        None
    }
}
242
/// If `<` sits at byte `i`, returns the offset just past the next `>`;
/// `None` when `i` is not `<` or no `>` follows.
///
/// This is a purely lexical heuristic: any `<` is treated as a tag opener
/// and the first following `>` as its end.
fn skip_jsx_tag(source: &str, bytes: &[u8], i: usize) -> Option<usize> {
    match bytes[i] {
        b'<' => source[i + 1..].find('>').map(|rel| i + rel + 2),
        _ => None,
    }
}
250
/// Finds the byte offset of the `$` that closes an inline formula whose
/// body starts at the beginning of `inline`.
///
/// A `$` directly preceded by a backslash is skipped. Returns `None` if a
/// newline is reached first or no closing `$` exists.
fn find_inline_close(inline: &str) -> Option<usize> {
    let bytes = inline.as_bytes();
    // `$` and `\n` are ASCII, so a byte scan never matches inside a
    // multi-byte character.
    for (idx, &byte) in bytes.iter().enumerate() {
        match byte {
            b'\n' => return None,
            b'$' if idx == 0 || bytes[idx - 1] != b'\\' => return Some(idx),
            _ => {}
        }
    }
    None
}
267
/// Escapes `s` for use inside a double-quoted JSX attribute value.
///
/// `&` is replaced first so that the ampersand introduced by `&quot;` is not
/// escaped a second time.
fn escape_jsx_attr(s: &str) -> String {
    s.replace('&', "&amp;").replace('"', "&quot;")
}
273}
274
275impl Transformer for Math {
276 fn name(&self) -> &str {
277 "math"
278 }
279 fn transform(&self, doc: &mut Document, _meta: &SourceMeta, _engine: &mut DiagnosticEngine<Code>) {
280 let mut v = Apply;
281 walk_root(&mut doc.children, &mut v);
282 }
283}
284
285struct Apply;
286
287impl Visitor for Apply {
288 fn visit_node(&mut self, node: &mut Node) -> NodeAction {
289 if let Node::Paragraph(p) = node
290 && let [Node::Text(t)] = p.children.as_slice()
291 && let Some(latex) = Math::unwrap_block(t.value.trim())
292 {
293 let span = t.span.clone();
294 return NodeAction::Replace(vec![Math::render_node(latex, true, &span)]);
295 }
296 let Node::Text(t) = node else { return NodeAction::Keep };
297 let Some(replacement) = Math::expand_inline(&t.value, &t.span) else { return NodeAction::Keep };
298 NodeAction::Replace(replacement)
299 }
300}
301
302impl Math {
/// Returns the trimmed body of `s` when `s` is exactly one `$$…$$` block.
///
/// Returns `None` when `s` does not both start and end with `$$`, or when
/// the body itself contains `$$`: text such as `$$a$$ and $$b$$` is several
/// formulas with prose in between, not one display block.
fn unwrap_block(s: &str) -> Option<&str> {
    let inner = s.trim().strip_prefix("$$")?.strip_suffix("$$")?;
    if inner.contains("$$") {
        return None;
    }
    Some(inner.trim())
}
308
309 fn expand_inline(text: &str, span: &duck_diagnostic::Span) -> Option<Vec<Node>> {
310 if !text.contains('$') {
311 return None;
312 }
313 let mut out: Vec<Node> = Vec::new();
314 let mut buf = String::new();
315 let bytes = text.as_bytes();
316 let mut i = 0;
317 let mut found_any = false;
318
319 while i < bytes.len() {
320 let c = bytes[i];
321 if c == b'\\' && i + 1 < bytes.len() && bytes[i + 1] == b'$' {
322 buf.push('$');
323 i += 2;
324 continue;
325 }
326 if c != b'$' {
327 let ch_len = utf8_char_len(bytes[i]);
328 buf.push_str(&text[i..i + ch_len]);
329 i += ch_len;
330 continue;
331 }
332 let (delim, display) = if i + 1 < bytes.len() && bytes[i + 1] == b'$' { ("$$", true) } else { ("$", false) };
333 let inner_start = i + delim.len();
334 let Some(close_off) = Self::find_unescaped(&text[inner_start..], delim) else {
335 buf.push('$');
336 i += 1;
337 continue;
338 };
339 let inner = &text[inner_start..inner_start + close_off];
340 if !buf.is_empty() {
341 out.push(Node::Text(Text { value: std::mem::take(&mut buf), span: span.clone() }));
342 }
343 out.push(Self::render_node(inner, display, span));
344 i = inner_start + close_off + delim.len();
345 found_any = true;
346 }
347
348 if !found_any {
349 return None;
350 }
351 if !buf.is_empty() {
352 out.push(Node::Text(Text { value: buf, span: span.clone() }));
353 }
354 Some(out)
355 }
356
/// Returns the byte offset of the first occurrence of `delim` in `haystack`
/// that is not directly preceded by a backslash, or `None` if there is none.
fn find_unescaped(haystack: &str, delim: &str) -> Option<usize> {
    let bytes = haystack.as_bytes();
    let mut search_from = 0;
    while search_from < haystack.len() {
        let abs = search_from + haystack[search_from..].find(delim)?;
        let escaped = abs > 0 && bytes[abs - 1] == b'\\';
        if !escaped {
            return Some(abs);
        }
        // Only the first character of this match is escaped. Step past that
        // one character rather than the whole delimiter, so a real closer
        // overlapping the escaped match (e.g. `\$` followed by `$$`) is
        // still found.
        search_from = abs + haystack[abs..].chars().next().map_or(1, char::len_utf8);
    }
    None
}
370}
371
372fn engine_to_u8(e: crate::MathEngine) -> u8 {
373 match e {
374 crate::MathEngine::Katex => 0,
375 crate::MathEngine::Mathml => 1,
376 }
377}
378
379fn u8_to_engine(b: u8) -> crate::MathEngine {
380 match b {
381 1 => crate::MathEngine::Mathml,
382 _ => crate::MathEngine::Katex,
383 }
384}
385
/// Returns the byte length of the UTF-8 sequence that starts with lead byte
/// `b`. Callers are expected to pass the first byte of a character.
fn utf8_char_len(b: u8) -> usize {
    match b {
        0x00..=0x7F => 1,
        0x80..=0xDF => 2,
        0xE0..=0xEF => 3,
        _ => 4,
    }
}
397
#[cfg(test)]
mod tests {
    use super::*;
    use duck_diagnostic::Span;
    use std::sync::Arc;

    // Placeholder span shared by every node built in these tests.
    fn s() -> Span {
        Span { file: Arc::from("<t>"), line: 1, column: 1, length: 0 }
    }

    // Text without any `$` is reported as "nothing to replace".
    #[test]
    fn passthrough_when_no_dollars() {
        assert!(Math::expand_inline("nothing here", &s()).is_none());
    }

    // One inline formula splits the text into: leading text, math node, trailing text.
    #[test]
    fn inline_math_replaces_one_span() {
        let r = Math::expand_inline("a $x+1$ b", &s()).expect("matched");
        assert_eq!(r.len(), 3);
        assert!(matches!(&r[0], Node::Text(t) if t.value == "a "));
        assert!(matches!(&r[1], Node::JsxSelfClosing(e) if e.name == "MathMl"));
        assert!(matches!(&r[2], Node::Text(t) if t.value == " b"));
    }

    // `\$` must not open a formula; either no replacement happens or the
    // first node is text that still contains the literal `$5`.
    #[test]
    fn escaped_dollar_is_literal() {
        let r = Math::expand_inline(r"price \$5", &s());
        assert!(r.is_none() || matches!(&r.unwrap()[0], Node::Text(t) if t.value.contains("$5")));
    }

    // A lone `$` with no closing delimiter leaves the text untouched.
    #[test]
    fn unmatched_dollar_left_alone() {
        assert!(Math::expand_inline("a $ stray", &s()).is_none());
    }

    // A paragraph consisting solely of `$$ … $$` is replaced by one MathMl
    // element whose `mathml` attribute contains `<math` and `display="block"`.
    #[test]
    fn block_math_unwraps() {
        let s_ = s();
        let mut p = Document {
            children: vec![Node::Paragraph(Paragraph {
                children: vec![Node::Text(Text { value: "$$ x = y $$".into(), span: s_.clone() })],
                span: s_.clone(),
            })],
            span: s_,
        };
        let mut v = Apply;
        walk_root(&mut p.children, &mut v);
        assert_eq!(p.children.len(), 1);
        if let Node::JsxSelfClosing(e) = &p.children[0] {
            assert_eq!(e.name, "MathMl");
            let mathml = e.attrs.iter().find(|a| a.name == "mathml").unwrap();
            assert!(
                matches!(&mathml.value, JsxAttrValue::String(s) if s.contains("<math") && s.contains("display=\"block\""))
            );
        } else {
            panic!("expected MathMl element, got {:?}", p.children[0]);
        }
    }
}