Skip to main content

rdx_math/
lib.rs

1pub mod parser;
2pub mod symbols;
3/// `rdx-math` — LaTeX math parser for the RDX specification.
4///
5/// Parses LaTeX math strings (without `$` delimiters) into structured [`MathExpr`] trees.
6///
7/// # Example
8///
9/// ```rust
10/// use rdx_math::parse;
11///
12/// let expr = parse(r"\frac{a}{b}");
13/// ```
14pub mod tokenizer;
15
16use std::collections::HashMap;
17
18// ─── Re-export rdx-ast types ─────────────────────────────────────────────────
19
20pub use rdx_ast::{
21    AccentKind, AlignRow, CaseRow, ColumnAlign, Delimiter, FracStyle, LimitStyle, MathExpr,
22    MathFont, MathOperator, MathSpace, MathStyle, MatrixDelimiters, OperatorKind, SmashMode,
23};
24
25// ─── Macro definition ─────────────────────────────────────────────────────────
26
/// A user-defined LaTeX macro.
///
/// A definition pairs the number of arguments the macro consumes with the template that
/// replaces each use of the macro. The macro's name is not stored here; it is the key under
/// which the definition is registered in the map passed to [`parse_with_macros`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MacroDef {
    /// Number of arguments (0 for nullary macros).
    pub arity: u8,
    /// Template string using `#1`, `#2`, … as argument placeholders.
    pub template: String,
}
34
35// ─── Public API ───────────────────────────────────────────────────────────────
36
37/// Parse a LaTeX math string into a [`MathExpr`] tree.
38///
39/// The input must not include the surrounding `$` delimiters.
40///
41/// Any unrecognised constructs are wrapped in [`MathExpr::Error`] nodes rather than
42/// causing a panic.
43pub fn parse(input: &str) -> MathExpr {
44    let tokens = tokenizer::tokenize(input);
45    let mut ts = tokenizer::TokenStream::new(tokens);
46    parser::parse_expr(&mut ts)
47}
48
49/// Parse a LaTeX math string with user-defined macro expansion.
50///
51/// Macros are expanded before parsing. The `macros` map keys must include the backslash
52/// (e.g., `"\\R"`). Each expansion step decrements an internal depth counter; if the
53/// counter reaches zero the expansion is aborted and an [`MathExpr::Error`] is returned
54/// for the affected expression.
55pub fn parse_with_macros(input: &str, macros: &HashMap<String, MacroDef>) -> MathExpr {
56    match expand_macros(input, macros, 64) {
57        Ok(expanded) => parse(&expanded),
58        Err(msg) => MathExpr::Error {
59            raw: input.to_string(),
60            message: msg,
61        },
62    }
63}
64
65// ─── Macro expansion ──────────────────────────────────────────────────────────
66
67/// Expand all macros in `input`, up to `max_depth` recursion levels.
68fn expand_macros(
69    input: &str,
70    macros: &HashMap<String, MacroDef>,
71    max_depth: usize,
72) -> Result<String, String> {
73    if max_depth == 0 {
74        return Err(
75            "macro expansion depth limit (64) exceeded — possible infinite loop".to_string(),
76        );
77    }
78
79    let mut result = String::with_capacity(input.len());
80    let chars: Vec<char> = input.chars().collect();
81    let n = chars.len();
82    let mut i = 0;
83
84    while i < n {
85        if chars[i] != '\\' {
86            result.push(chars[i]);
87            i += 1;
88            continue;
89        }
90
91        // We have a backslash.  Try to match a macro name.
92        let macro_start = i;
93        i += 1; // skip '\'
94
95        if i >= n {
96            result.push('\\');
97            continue;
98        }
99
100        // Collect the command name
101        let name_start = i;
102        if chars[i].is_ascii_alphabetic() {
103            while i < n && chars[i].is_ascii_alphabetic() {
104                i += 1;
105            }
106        } else {
107            // Single non-alpha symbol command
108            i += 1;
109        }
110        let cmd_name: String = chars[name_start..i].iter().collect();
111        let full_name = format!("\\{cmd_name}");
112
113        if let Some(def) = macros.get(&full_name) {
114            // Collect arguments
115            let mut args: Vec<String> = Vec::new();
116            let mut j = i;
117
118            for _ in 0..def.arity {
119                // Skip whitespace
120                while j < n && chars[j].is_ascii_whitespace() {
121                    j += 1;
122                }
123                if j >= n {
124                    break;
125                }
126                if chars[j] == '{' {
127                    // Brace-delimited argument
128                    j += 1; // skip {
129                    let arg_start = j;
130                    let mut depth = 1usize;
131                    while j < n && depth > 0 {
132                        match chars[j] {
133                            '{' => depth += 1,
134                            '}' => depth -= 1,
135                            _ => {}
136                        }
137                        if depth > 0 {
138                            j += 1;
139                        } else {
140                            // closing brace
141                            break;
142                        }
143                    }
144                    let arg: String = chars[arg_start..j].iter().collect();
145                    if j < n && chars[j] == '}' {
146                        j += 1; // skip closing }
147                    }
148                    args.push(arg);
149                } else {
150                    // Single character argument
151                    args.push(chars[j].to_string());
152                    j += 1;
153                }
154            }
155            i = j;
156
157            // Substitute arguments into template
158            let mut expansion = def.template.clone();
159            for (k, arg) in args.iter().enumerate() {
160                let placeholder = format!("#{}", k + 1);
161                expansion = expansion.replace(&placeholder, arg);
162            }
163
164            // Recursively expand the expansion
165            let sub = expand_macros(&expansion, macros, max_depth - 1)?;
166            result.push_str(&sub);
167        } else {
168            // Not a macro — emit verbatim
169            let raw: String = chars[macro_start..i].iter().collect();
170            result.push_str(&raw);
171        }
172    }
173
174    Ok(result)
175}
176
177// ─── Tests ────────────────────────────────────────────────────────────────────
178
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared helpers ──────────────────────────────────────────────────────

    /// Build an identifier leaf.
    fn ident(value: &str) -> MathExpr {
        MathExpr::Ident {
            value: value.to_string(),
        }
    }

    /// Build a number leaf.
    fn number(value: &str) -> MathExpr {
        MathExpr::Number {
            value: value.to_string(),
        }
    }

    /// Build a macro table holding exactly one definition.
    fn single_macro(name: &str, arity: u8, template: &str) -> HashMap<String, MacroDef> {
        let mut macros = HashMap::new();
        macros.insert(
            name.to_string(),
            MacroDef {
                arity,
                template: template.to_string(),
            },
        );
        macros
    }

    /// Assert that every command name in `letters` parses to an identifier.
    fn assert_all_idents(letters: &[&str]) {
        for name in letters {
            let expr = parse(&format!("\\{name}"));
            assert!(
                matches!(expr, MathExpr::Ident { .. }),
                "\\{name} should be Ident, got {:?}",
                expr
            );
        }
    }

    // ── Basic constructs ────────────────────────────────────────────────────

    #[test]
    fn simple_fraction() {
        let expected = MathExpr::Frac {
            numerator: Box::new(ident("a")),
            denominator: Box::new(ident("b")),
            style: FracStyle::Auto,
        };
        assert_eq!(parse(r"\frac{a}{b}"), expected);
    }

    #[test]
    fn superscript() {
        let expected = MathExpr::Superscript {
            base: Box::new(ident("x")),
            script: Box::new(number("2")),
        };
        assert_eq!(parse("x^2"), expected);
    }

    #[test]
    fn subscript_superscript() {
        let expected = MathExpr::Subsuperscript {
            base: Box::new(ident("x")),
            sub: Box::new(ident("i")),
            sup: Box::new(number("2")),
        };
        assert_eq!(parse("x_i^2"), expected);
    }

    #[test]
    fn sum_with_limits() {
        // The big operator and the subscripted term end up side by side in a Row.
        let parsed = parse(r"\sum_{i=0}^{n} a_i");
        assert!(matches!(parsed, MathExpr::Row { .. }));
    }

    #[test]
    fn nested_fractions() {
        let outer = parse(r"\frac{\frac{a}{b}}{c}");
        assert!(matches!(outer, MathExpr::Frac { .. }));
        if let MathExpr::Frac { numerator, .. } = &outer {
            assert!(matches!(**numerator, MathExpr::Frac { .. }));
        }
    }

    #[test]
    fn left_right_delimiters() {
        let fenced = parse(r"\left( \frac{a}{b} \right)");
        assert!(matches!(
            fenced,
            MathExpr::Fenced {
                open: Delimiter::Paren,
                close: Delimiter::Paren,
                ..
            }
        ));
    }

    #[test]
    fn sqrt_with_index() {
        let root = parse(r"\sqrt[3]{x}");
        assert!(matches!(root, MathExpr::Sqrt { index: Some(_), .. }));
    }

    #[test]
    fn unknown_command_error_recovery() {
        let parsed = parse(r"\frac{a}{\unknowncmd}");
        // The surrounding fraction must survive the bad command ...
        assert!(matches!(parsed, MathExpr::Frac { .. }));
        // ... and only the denominator turns into an Error node.
        if let MathExpr::Frac { denominator, .. } = parsed {
            assert!(
                matches!(*denominator, MathExpr::Error { .. }),
                "expected Error node for unknown command, got {:?}",
                *denominator
            );
        }
    }

    #[test]
    fn greek_letters() {
        assert_eq!(parse(r"\alpha"), ident("α"));
        assert_eq!(parse(r"\beta"), ident("β"));

        // Several atoms in sequence form a Row.
        let sum = parse(r"\alpha + \beta");
        assert!(matches!(sum, MathExpr::Row { .. }));
    }

    #[test]
    fn text_in_math() {
        let expected = MathExpr::Text {
            value: "hello world".to_string(),
        };
        assert_eq!(parse(r"\text{hello world}"), expected);
    }

    // ── Macros ──────────────────────────────────────────────────────────────

    #[test]
    fn macro_expansion_nullary() {
        let macros = single_macro("\\R", 0, "\\mathbb{R}");
        let parsed = parse_with_macros(r"x \in \R", &macros);
        // Expected shape: Row[Ident("x"), Operator("∈"), FontOverride(Blackboard, Ident("R"))]
        assert!(matches!(parsed, MathExpr::Row { .. }));
        if let MathExpr::Row { children } = &parsed {
            let last = children.last().unwrap();
            assert!(
                matches!(
                    last,
                    MathExpr::FontOverride {
                        font: MathFont::Blackboard,
                        ..
                    }
                ),
                "expected FontOverride(Blackboard, ...), got {:?}",
                last
            );
        }
    }

    #[test]
    fn macro_expansion_with_arg() {
        let macros = single_macro("\\norm", 1, "\\left\\lVert #1 \\right\\rVert");
        // \norm{x+y} → \left\lVert x+y \right\rVert → Fenced(DoublePipe, ...)
        let parsed = parse_with_macros(r"\norm{x+y}", &macros);
        assert!(
            matches!(parsed, MathExpr::Fenced { .. }),
            "expected Fenced, got {:?}",
            parsed
        );
    }

    #[test]
    fn empty_input() {
        let expected = MathExpr::Row {
            children: Vec::new(),
        };
        assert_eq!(parse(""), expected);
    }

    #[test]
    fn spacing_commands() {
        assert_eq!(parse(r"\,"), MathExpr::Space(MathSpace::Thin));
        assert_eq!(parse(r"\quad"), MathExpr::Space(MathSpace::Quad));

        // Mixed with identifiers the result is a Row.
        let mixed = parse(r"a \, b \quad c");
        assert!(matches!(mixed, MathExpr::Row { .. }));
    }

    #[test]
    fn relational_operators() {
        for source in [r"\leq", r"\neq"] {
            let parsed = parse(source);
            assert!(matches!(
                parsed,
                MathExpr::Operator(MathOperator {
                    kind: OperatorKind::Relation,
                    ..
                })
            ));
        }
    }

    #[test]
    fn sum_with_sub_and_sup() {
        let parsed = parse(r"\sum_{i=0}^{n}");
        assert!(matches!(
            parsed,
            MathExpr::BigOperator {
                lower: Some(_),
                upper: Some(_),
                ..
            }
        ));
    }

    #[test]
    fn macro_expansion_depth_limit() {
        // A macro that expands to itself must terminate with an Error, not loop or panic.
        let macros = single_macro("\\bad", 0, "\\bad");
        let parsed = parse_with_macros(r"\bad", &macros);
        assert!(
            matches!(parsed, MathExpr::Error { .. }),
            "expected Error for infinite macro, got {:?}",
            parsed
        );
    }

    // ── Symbol tables ───────────────────────────────────────────────────────

    #[test]
    fn all_greek_lowercase() {
        assert_all_idents(&[
            "alpha",
            "beta",
            "gamma",
            "delta",
            "epsilon",
            "varepsilon",
            "zeta",
            "eta",
            "theta",
            "vartheta",
            "iota",
            "kappa",
            "lambda",
            "mu",
            "nu",
            "xi",
            "pi",
            "varpi",
            "rho",
            "varrho",
            "sigma",
            "varsigma",
            "tau",
            "upsilon",
            "phi",
            "varphi",
            "chi",
            "psi",
            "omega",
        ]);
    }

    #[test]
    fn all_greek_uppercase() {
        assert_all_idents(&[
            "Gamma", "Delta", "Theta", "Lambda", "Xi", "Pi", "Sigma", "Upsilon", "Phi", "Psi",
            "Omega",
        ]);
    }

    #[test]
    fn all_tier1_operators() {
        let ops = [
            r"\times",
            r"\cdot",
            r"\pm",
            r"\mp",
            r"\div",
            r"\neq",
            r"\leq",
            r"\geq",
            r"\approx",
            r"\equiv",
            r"\sim",
            r"\cong",
            r"\propto",
            r"\in",
            r"\notin",
            r"\subset",
            r"\supset",
            r"\cup",
            r"\cap",
            r"\land",
            r"\lor",
            r"\neg",
            r"\implies",
            r"\iff",
        ];
        for op in ops {
            let parsed = parse(op);
            assert!(
                matches!(parsed, MathExpr::Operator(_)),
                "{op} should be Operator, got {:?}",
                parsed
            );
        }
    }

    #[test]
    fn all_large_operators() {
        let ops = [
            r"\sum", r"\prod", r"\int", r"\iint", r"\iiint", r"\oint", r"\bigcup", r"\bigcap",
        ];
        for op in ops {
            let parsed = parse(op);
            assert!(
                matches!(parsed, MathExpr::BigOperator { .. }),
                "{op} should be BigOperator, got {:?}",
                parsed
            );
        }
    }

    // ── Fractions and delimiters ────────────────────────────────────────────

    #[test]
    fn frac_styles() {
        assert!(matches!(
            parse(r"\frac{1}{2}"),
            MathExpr::Frac {
                style: FracStyle::Auto,
                ..
            }
        ));
        assert!(matches!(
            parse(r"\dfrac{1}{2}"),
            MathExpr::Frac {
                style: FracStyle::Display,
                ..
            }
        ));
        assert!(matches!(
            parse(r"\tfrac{1}{2}"),
            MathExpr::Frac {
                style: FracStyle::Text,
                ..
            }
        ));
    }

    #[test]
    fn delimiter_variants() {
        assert!(matches!(
            parse(r"\left( x \right)"),
            MathExpr::Fenced {
                open: Delimiter::Paren,
                close: Delimiter::Paren,
                ..
            }
        ));
        assert!(matches!(
            parse(r"\left[ x \right]"),
            MathExpr::Fenced {
                open: Delimiter::Bracket,
                close: Delimiter::Bracket,
                ..
            }
        ));
        assert!(matches!(
            parse(r"\left\{ x \right\}"),
            MathExpr::Fenced {
                open: Delimiter::Brace,
                close: Delimiter::Brace,
                ..
            }
        ));
        assert!(matches!(
            parse(r"\left\langle x \right\rangle"),
            MathExpr::Fenced {
                open: Delimiter::Angle,
                close: Delimiter::Angle,
                ..
            }
        ));
    }

    #[test]
    fn invisible_delimiter() {
        let fenced = parse(r"\left. x \right|");
        assert!(matches!(
            fenced,
            MathExpr::Fenced {
                open: Delimiter::None,
                close: Delimiter::Pipe,
                ..
            }
        ));
    }

    #[test]
    fn partial_and_nabla() {
        assert_eq!(parse(r"\partial"), ident("∂"));
        assert_eq!(parse(r"\nabla"), ident("∇"));
    }

    #[test]
    fn mathrm_produces_font_override() {
        // Even a single letter under \mathrm is expected to come back as a FontOverride.
        let parsed = parse(r"\mathrm{d}");
        assert!(
            matches!(
                parsed,
                MathExpr::FontOverride {
                    font: MathFont::Roman,
                    ..
                }
            ),
            "expected FontOverride(Roman), got {:?}",
            parsed
        );
    }

    // ── Tier 2: accents, over/under, environments ───────────────────────────

    #[test]
    fn tier2_accent_commands() {
        let accents = [
            (r"\hat{x}", AccentKind::Hat),
            (r"\tilde{x}", AccentKind::Tilde),
            (r"\vec{x}", AccentKind::Vec),
            (r"\dot{x}", AccentKind::Dot),
            (r"\ddot{x}", AccentKind::Ddot),
            (r"\bar{x}", AccentKind::Bar),
        ];
        for (input, expected_kind) in accents {
            let parsed = parse(input);
            assert!(
                matches!(&parsed, MathExpr::Accent { kind, .. } if *kind == expected_kind),
                "{input} should be Accent({:?}), got {:?}",
                expected_kind,
                parsed
            );
        }
    }

    #[test]
    fn tier2_over_under() {
        assert!(matches!(parse(r"\overline{x}"), MathExpr::Overline { .. }));
        assert!(matches!(
            parse(r"\underline{x}"),
            MathExpr::Underline { .. }
        ));
        assert!(matches!(
            parse(r"\overbrace{x}"),
            MathExpr::Overbrace { .. }
        ));
        assert!(matches!(
            parse(r"\underbrace{x}"),
            MathExpr::Underbrace { .. }
        ));
    }

    #[test]
    fn pmatrix_environment() {
        let parsed = parse(r"\begin{pmatrix} a & b \\ c & d \end{pmatrix}");
        assert!(
            matches!(
                parsed,
                MathExpr::Matrix {
                    delimiters: MatrixDelimiters::Paren,
                    ..
                }
            ),
            "expected pmatrix, got {:?}",
            parsed
        );
    }

    #[test]
    fn cases_environment() {
        let parsed = parse(r"\begin{cases} x & x > 0 \\ -x & x \leq 0 \end{cases}");
        assert!(
            matches!(parsed, MathExpr::Cases { .. }),
            "expected Cases, got {:?}",
            parsed
        );
    }

    #[test]
    fn align_environment() {
        let parsed = parse(r"\begin{align} x &= 1 \\ y &= 2 \end{align}");
        assert!(
            matches!(parsed, MathExpr::Align { .. }),
            "expected Align, got {:?}",
            parsed
        );
    }

    #[test]
    fn unknown_environment_error() {
        let parsed = parse(r"\begin{unknownenv} x \end{unknownenv}");
        assert!(
            matches!(parsed, MathExpr::Error { .. }),
            "expected Error, got {:?}",
            parsed
        );
    }

    #[test]
    fn never_panics_on_malformed() {
        // Only the absence of a panic is checked; the parse results are discarded.
        let inputs = [
            r"\frac{}{",        // missing }
            r"\frac{}",         // missing second arg
            r"\sqrt[",          // unclosed optional
            r"\left(",          // missing \right
            r"^{x}",            // dangling script
            r"\begin{pmatrix}", // missing \end
            r"\color{red}",     // missing body
        ];
        for input in inputs {
            let _ = parse(input);
        }
    }
}
779}