1pub mod parser;
2pub mod symbols;
3pub mod tokenizer;
15
16use std::collections::HashMap;
17
18pub use rdx_ast::{
21 AccentKind, AlignRow, CaseRow, ColumnAlign, Delimiter, FracStyle, LimitStyle, MathExpr,
22 MathFont, MathOperator, MathSpace, MathStyle, MatrixDelimiters, OperatorKind, SmashMode,
23};
24
/// A user-defined macro that is expanded textually before parsing.
///
/// Definitions are looked up in a map keyed by the full control-sequence name,
/// leading backslash included (for example `"\\norm"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MacroDef {
    /// Number of arguments collected after the macro name at each use site.
    pub arity: u8,
    /// Replacement text; `#1`, `#2`, … mark where the collected arguments are
    /// inserted.
    pub template: String,
}
34
35pub fn parse(input: &str) -> MathExpr {
44 let tokens = tokenizer::tokenize(input);
45 let mut ts = tokenizer::TokenStream::new(tokens);
46 parser::parse_expr(&mut ts)
47}
48
49pub fn parse_with_macros(input: &str, macros: &HashMap<String, MacroDef>) -> MathExpr {
56 match expand_macros(input, macros, 64) {
57 Ok(expanded) => parse(&expanded),
58 Err(msg) => MathExpr::Error {
59 raw: input.to_string(),
60 message: msg,
61 },
62 }
63}
64
65fn expand_macros(
69 input: &str,
70 macros: &HashMap<String, MacroDef>,
71 max_depth: usize,
72) -> Result<String, String> {
73 if max_depth == 0 {
74 return Err(
75 "macro expansion depth limit (64) exceeded — possible infinite loop".to_string(),
76 );
77 }
78
79 let mut result = String::with_capacity(input.len());
80 let chars: Vec<char> = input.chars().collect();
81 let n = chars.len();
82 let mut i = 0;
83
84 while i < n {
85 if chars[i] != '\\' {
86 result.push(chars[i]);
87 i += 1;
88 continue;
89 }
90
91 let macro_start = i;
93 i += 1; if i >= n {
96 result.push('\\');
97 continue;
98 }
99
100 let name_start = i;
102 if chars[i].is_ascii_alphabetic() {
103 while i < n && chars[i].is_ascii_alphabetic() {
104 i += 1;
105 }
106 } else {
107 i += 1;
109 }
110 let cmd_name: String = chars[name_start..i].iter().collect();
111 let full_name = format!("\\{cmd_name}");
112
113 if let Some(def) = macros.get(&full_name) {
114 let mut args: Vec<String> = Vec::new();
116 let mut j = i;
117
118 for _ in 0..def.arity {
119 while j < n && chars[j].is_ascii_whitespace() {
121 j += 1;
122 }
123 if j >= n {
124 break;
125 }
126 if chars[j] == '{' {
127 j += 1; let arg_start = j;
130 let mut depth = 1usize;
131 while j < n && depth > 0 {
132 match chars[j] {
133 '{' => depth += 1,
134 '}' => depth -= 1,
135 _ => {}
136 }
137 if depth > 0 {
138 j += 1;
139 } else {
140 break;
142 }
143 }
144 let arg: String = chars[arg_start..j].iter().collect();
145 if j < n && chars[j] == '}' {
146 j += 1; }
148 args.push(arg);
149 } else {
150 args.push(chars[j].to_string());
152 j += 1;
153 }
154 }
155 i = j;
156
157 let mut expansion = def.template.clone();
159 for (k, arg) in args.iter().enumerate() {
160 let placeholder = format!("#{}", k + 1);
161 expansion = expansion.replace(&placeholder, arg);
162 }
163
164 let sub = expand_macros(&expansion, macros, max_depth - 1)?;
166 result.push_str(&sub);
167 } else {
168 let raw: String = chars[macro_start..i].iter().collect();
170 result.push_str(&raw);
171 }
172 }
173
174 Ok(result)
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected node for a single identifier.
    fn ident(value: &str) -> MathExpr {
        MathExpr::Ident {
            value: value.to_string(),
        }
    }

    /// Expected node for a numeric literal.
    fn number(value: &str) -> MathExpr {
        MathExpr::Number {
            value: value.to_string(),
        }
    }

    /// Macro table holding exactly one definition.
    fn single_macro(name: &str, arity: u8, template: &str) -> HashMap<String, MacroDef> {
        let mut macros = HashMap::new();
        macros.insert(
            name.to_string(),
            MacroDef {
                arity,
                template: template.to_string(),
            },
        );
        macros
    }

    #[test]
    fn simple_fraction() {
        let expr = parse(r"\frac{a}{b}");
        assert_eq!(
            expr,
            MathExpr::Frac {
                numerator: Box::new(ident("a")),
                denominator: Box::new(ident("b")),
                style: FracStyle::Auto,
            }
        );
    }

    #[test]
    fn superscript() {
        let expr = parse("x^2");
        assert_eq!(
            expr,
            MathExpr::Superscript {
                base: Box::new(ident("x")),
                script: Box::new(number("2")),
            }
        );
    }

    #[test]
    fn subscript_superscript() {
        let expr = parse("x_i^2");
        assert_eq!(
            expr,
            MathExpr::Subsuperscript {
                base: Box::new(ident("x")),
                sub: Box::new(ident("i")),
                sup: Box::new(number("2")),
            }
        );
    }

    #[test]
    fn sum_with_limits() {
        let expr = parse(r"\sum_{i=0}^{n} a_i");
        assert!(matches!(expr, MathExpr::Row { .. }));
    }

    #[test]
    fn nested_fractions() {
        let expr = parse(r"\frac{\frac{a}{b}}{c}");
        assert!(matches!(expr, MathExpr::Frac { .. }));
        if let MathExpr::Frac { numerator, .. } = &expr {
            assert!(matches!(**numerator, MathExpr::Frac { .. }));
        }
    }

    #[test]
    fn left_right_delimiters() {
        let expr = parse(r"\left( \frac{a}{b} \right)");
        assert!(matches!(
            expr,
            MathExpr::Fenced {
                open: Delimiter::Paren,
                close: Delimiter::Paren,
                ..
            }
        ));
    }

    #[test]
    fn sqrt_with_index() {
        let expr = parse(r"\sqrt[3]{x}");
        assert!(matches!(expr, MathExpr::Sqrt { index: Some(_), .. }));
    }

    #[test]
    fn unknown_command_error_recovery() {
        let expr = parse(r"\frac{a}{\unknowncmd}");
        assert!(matches!(expr, MathExpr::Frac { .. }));
        if let MathExpr::Frac { denominator, .. } = expr {
            assert!(
                matches!(*denominator, MathExpr::Error { .. }),
                "expected Error node for unknown command, got {:?}",
                *denominator
            );
        }
    }

    #[test]
    fn greek_letters() {
        let alpha = parse(r"\alpha");
        assert_eq!(alpha, ident("α"));

        let beta = parse(r"\beta");
        assert_eq!(beta, ident("β"));

        let expr = parse(r"\alpha + \beta");
        assert!(matches!(expr, MathExpr::Row { .. }));
    }

    #[test]
    fn text_in_math() {
        let expr = parse(r"\text{hello world}");
        assert_eq!(
            expr,
            MathExpr::Text {
                value: "hello world".to_string()
            }
        );
    }

    #[test]
    fn macro_expansion_nullary() {
        let macros = single_macro("\\R", 0, "\\mathbb{R}");
        let expr = parse_with_macros(r"x \in \R", &macros);
        assert!(matches!(expr, MathExpr::Row { .. }));
        if let MathExpr::Row { children } = &expr {
            let last = children.last().unwrap();
            assert!(
                matches!(
                    last,
                    MathExpr::FontOverride {
                        font: MathFont::Blackboard,
                        ..
                    }
                ),
                "expected FontOverride(Blackboard, ...), got {:?}",
                last
            );
        }
    }

    #[test]
    fn macro_expansion_with_arg() {
        let macros = single_macro("\\norm", 1, "\\left\\lVert #1 \\right\\rVert");
        let expr = parse_with_macros(r"\norm{x+y}", &macros);
        assert!(
            matches!(expr, MathExpr::Fenced { .. }),
            "expected Fenced, got {:?}",
            expr
        );
    }

    #[test]
    fn empty_input() {
        let expr = parse("");
        assert_eq!(
            expr,
            MathExpr::Row {
                children: Vec::new()
            }
        );
    }

    #[test]
    fn spacing_commands() {
        let thin = parse(r"\,");
        assert_eq!(thin, MathExpr::Space(MathSpace::Thin));

        let quad = parse(r"\quad");
        assert_eq!(quad, MathExpr::Space(MathSpace::Quad));

        let expr = parse(r"a \, b \quad c");
        assert!(matches!(expr, MathExpr::Row { .. }));
    }

    #[test]
    fn relational_operators() {
        let leq = parse(r"\leq");
        assert!(matches!(
            leq,
            MathExpr::Operator(MathOperator {
                kind: OperatorKind::Relation,
                ..
            })
        ));

        let neq = parse(r"\neq");
        assert!(matches!(
            neq,
            MathExpr::Operator(MathOperator {
                kind: OperatorKind::Relation,
                ..
            })
        ));
    }

    #[test]
    fn sum_with_sub_and_sup() {
        let expr = parse(r"\sum_{i=0}^{n}");
        assert!(matches!(
            expr,
            MathExpr::BigOperator {
                lower: Some(_),
                upper: Some(_),
                ..
            }
        ));
    }

    #[test]
    fn macro_expansion_depth_limit() {
        // A macro that expands to itself must be reported, not looped on.
        let macros = single_macro("\\bad", 0, "\\bad");
        let expr = parse_with_macros(r"\bad", &macros);
        assert!(
            matches!(expr, MathExpr::Error { .. }),
            "expected Error for infinite macro, got {:?}",
            expr
        );
    }

    #[test]
    fn all_greek_lowercase() {
        let letters = [
            "alpha",
            "beta",
            "gamma",
            "delta",
            "epsilon",
            "varepsilon",
            "zeta",
            "eta",
            "theta",
            "vartheta",
            "iota",
            "kappa",
            "lambda",
            "mu",
            "nu",
            "xi",
            "pi",
            "varpi",
            "rho",
            "varrho",
            "sigma",
            "varsigma",
            "tau",
            "upsilon",
            "phi",
            "varphi",
            "chi",
            "psi",
            "omega",
        ];
        for name in &letters {
            let expr = parse(&format!("\\{name}"));
            assert!(
                matches!(expr, MathExpr::Ident { .. }),
                "\\{name} should be Ident, got {:?}",
                expr
            );
        }
    }

    #[test]
    fn all_greek_uppercase() {
        let letters = [
            "Gamma", "Delta", "Theta", "Lambda", "Xi", "Pi", "Sigma", "Upsilon", "Phi", "Psi",
            "Omega",
        ];
        for name in &letters {
            let expr = parse(&format!("\\{name}"));
            assert!(
                matches!(expr, MathExpr::Ident { .. }),
                "\\{name} should be Ident, got {:?}",
                expr
            );
        }
    }

    #[test]
    fn all_tier1_operators() {
        let ops = [
            r"\times",
            r"\cdot",
            r"\pm",
            r"\mp",
            r"\div",
            r"\neq",
            r"\leq",
            r"\geq",
            r"\approx",
            r"\equiv",
            r"\sim",
            r"\cong",
            r"\propto",
            r"\in",
            r"\notin",
            r"\subset",
            r"\supset",
            r"\cup",
            r"\cap",
            r"\land",
            r"\lor",
            r"\neg",
            r"\implies",
            r"\iff",
        ];
        for op in &ops {
            let expr = parse(op);
            assert!(
                matches!(expr, MathExpr::Operator(_)),
                "{op} should be Operator, got {:?}",
                expr
            );
        }
    }

    #[test]
    fn all_large_operators() {
        let ops = [
            r"\sum", r"\prod", r"\int", r"\iint", r"\iiint", r"\oint", r"\bigcup", r"\bigcap",
        ];
        for op in &ops {
            let expr = parse(op);
            assert!(
                matches!(expr, MathExpr::BigOperator { .. }),
                "{op} should be BigOperator, got {:?}",
                expr
            );
        }
    }

    #[test]
    fn frac_styles() {
        let auto = parse(r"\frac{1}{2}");
        assert!(matches!(
            auto,
            MathExpr::Frac {
                style: FracStyle::Auto,
                ..
            }
        ));

        let display = parse(r"\dfrac{1}{2}");
        assert!(matches!(
            display,
            MathExpr::Frac {
                style: FracStyle::Display,
                ..
            }
        ));

        let text = parse(r"\tfrac{1}{2}");
        assert!(matches!(
            text,
            MathExpr::Frac {
                style: FracStyle::Text,
                ..
            }
        ));
    }

    #[test]
    fn delimiter_variants() {
        let paren = parse(r"\left( x \right)");
        assert!(matches!(
            paren,
            MathExpr::Fenced {
                open: Delimiter::Paren,
                close: Delimiter::Paren,
                ..
            }
        ));

        let bracket = parse(r"\left[ x \right]");
        assert!(matches!(
            bracket,
            MathExpr::Fenced {
                open: Delimiter::Bracket,
                close: Delimiter::Bracket,
                ..
            }
        ));

        let brace = parse(r"\left\{ x \right\}");
        assert!(matches!(
            brace,
            MathExpr::Fenced {
                open: Delimiter::Brace,
                close: Delimiter::Brace,
                ..
            }
        ));

        let angle = parse(r"\left\langle x \right\rangle");
        assert!(matches!(
            angle,
            MathExpr::Fenced {
                open: Delimiter::Angle,
                close: Delimiter::Angle,
                ..
            }
        ));
    }

    #[test]
    fn invisible_delimiter() {
        let expr = parse(r"\left. x \right|");
        assert!(matches!(
            expr,
            MathExpr::Fenced {
                open: Delimiter::None,
                close: Delimiter::Pipe,
                ..
            }
        ));
    }

    #[test]
    fn partial_and_nabla() {
        let partial = parse(r"\partial");
        assert_eq!(partial, ident("∂"));

        let nabla = parse(r"\nabla");
        assert_eq!(nabla, ident("∇"));
    }

    #[test]
    fn mathrm_produces_font_override() {
        let expr = parse(r"\mathrm{d}");
        assert!(
            matches!(
                expr,
                MathExpr::FontOverride {
                    font: MathFont::Roman,
                    ..
                }
            ),
            "expected FontOverride(Roman), got {:?}",
            expr
        );
    }

    #[test]
    fn tier2_accent_commands() {
        let accents = [
            (r"\hat{x}", AccentKind::Hat),
            (r"\tilde{x}", AccentKind::Tilde),
            (r"\vec{x}", AccentKind::Vec),
            (r"\dot{x}", AccentKind::Dot),
            (r"\ddot{x}", AccentKind::Ddot),
            (r"\bar{x}", AccentKind::Bar),
        ];
        for (input, expected_kind) in accents {
            let expr = parse(input);
            assert!(
                matches!(&expr, MathExpr::Accent { kind, .. } if *kind == expected_kind),
                "{input} should be Accent({:?}), got {:?}",
                expected_kind,
                expr
            );
        }
    }

    #[test]
    fn tier2_over_under() {
        let ol = parse(r"\overline{x}");
        assert!(matches!(ol, MathExpr::Overline { .. }));

        let ul = parse(r"\underline{x}");
        assert!(matches!(ul, MathExpr::Underline { .. }));

        let ob = parse(r"\overbrace{x}");
        assert!(matches!(ob, MathExpr::Overbrace { .. }));

        let ub = parse(r"\underbrace{x}");
        assert!(matches!(ub, MathExpr::Underbrace { .. }));
    }

    #[test]
    fn pmatrix_environment() {
        let expr = parse(r"\begin{pmatrix} a & b \\ c & d \end{pmatrix}");
        assert!(
            matches!(
                expr,
                MathExpr::Matrix {
                    delimiters: MatrixDelimiters::Paren,
                    ..
                }
            ),
            "expected pmatrix, got {:?}",
            expr
        );
    }

    #[test]
    fn cases_environment() {
        let expr = parse(r"\begin{cases} x & x > 0 \\ -x & x \leq 0 \end{cases}");
        assert!(
            matches!(expr, MathExpr::Cases { .. }),
            "expected Cases, got {:?}",
            expr
        );
    }

    #[test]
    fn align_environment() {
        let expr = parse(r"\begin{align} x &= 1 \\ y &= 2 \end{align}");
        assert!(
            matches!(expr, MathExpr::Align { .. }),
            "expected Align, got {:?}",
            expr
        );
    }

    #[test]
    fn unknown_environment_error() {
        let expr = parse(r"\begin{unknownenv} x \end{unknownenv}");
        assert!(
            matches!(expr, MathExpr::Error { .. }),
            "expected Error, got {:?}",
            expr
        );
    }

    #[test]
    fn never_panics_on_malformed() {
        // Only absence of a panic is checked; the resulting tree is ignored.
        let inputs = [
            r"\frac{}{",
            r"\frac{}",
            r"\sqrt[",
            r"\left(",
            r"^{x}",
            r"\begin{pmatrix}",
            r"\color{red}",
        ];
        for input in inputs {
            let _ = parse(input);
        }
    }
}