Skip to main content

ratex_parser/
parser.rs

1use ratex_lexer::token::{SourceLocation, Token};
2use unicode_normalization::UnicodeNormalization;
3
4use crate::error::{ParseError, ParseResult};
5use crate::functions::{self, ArgType, FunctionContext, FUNCTIONS};
6use crate::macro_expander::{MacroExpander, IMPLICIT_COMMANDS};
7use crate::parse_node::{AtomFamily, Mode, ParseNode};
8
/// End-of-expression tokens.
///
/// `parse_expression` stops, without consuming the token, as soon as the
/// lookahead token's text is one of these; the terminator is left in place for
/// whoever called `parse_expression` to check.
static END_OF_EXPRESSION: &[&str] = &["}", "\\endgroup", "\\end", "\\right", "&"];
11
/// The LaTeX parser. Converts a token stream into a ParseNode AST.
///
/// Follows KaTeX's Parser.ts closely:
/// - `parse()` → parse full expression
/// - `parseExpression()` → parse a list of atoms
/// - `parseAtom()` → parse one atom with optional super/subscripts
/// - `parseGroup()` → parse a group (braced or single token)
/// - `parseFunction()` → parse a function call with arguments
/// - `parseSymbol()` → parse a single symbol
pub struct Parser<'a> {
    /// Mode the parser is currently in; starts as `Mode::Math` and is changed
    /// through `switch_mode`, which keeps the gullet's mode in sync.
    pub mode: Mode,
    /// Macro expander the parser pulls its tokens from.
    pub gullet: MacroExpander<'a>,
    /// Starts at zero. NOTE(review): presumably the current `\left`/`\right`
    /// nesting depth maintained by function handlers — confirm against them.
    pub leftright_depth: i32,
    /// One-token lookahead buffer: filled by `fetch`, cleared by `consume`.
    next_token: Option<Token>,
}
27
28impl<'a> Parser<'a> {
29    pub fn new(input: &'a str) -> Self {
30        Self {
31            mode: Mode::Math,
32            gullet: MacroExpander::new(input, Mode::Math),
33            leftright_depth: 0,
34            next_token: None,
35        }
36    }
37
38    // ── Token management ────────────────────────────────────────────────
39
40    /// Return the current lookahead token (fetching from gullet if needed).
41    pub fn fetch(&mut self) -> ParseResult<Token> {
42        if self.next_token.is_none() {
43            self.next_token = Some(self.gullet.expand_next_token()?);
44        }
45        Ok(self.next_token.clone().unwrap())
46    }
47
48    /// Discard the current lookahead token.
49    pub fn consume(&mut self) {
50        self.next_token = None;
51    }
52
53    /// Expect the next token to have the given text, consuming it.
54    pub fn expect(&mut self, text: &str, do_consume: bool) -> ParseResult<()> {
55        let tok = self.fetch()?;
56        if tok.text != text {
57            return Err(ParseError::new(
58                format!("Expected '{}', got '{}'", text, tok.text),
59                Some(&tok),
60            ));
61        }
62        if do_consume {
63            self.consume();
64        }
65        Ok(())
66    }
67
68    /// Consume spaces in math mode.
69    pub fn consume_spaces(&mut self) -> ParseResult<()> {
70        loop {
71            let tok = self.fetch()?;
72            if tok.text == " " {
73                self.consume();
74            } else {
75                break;
76            }
77        }
78        Ok(())
79    }
80
81    /// Switch between "math" and "text" modes.
82    pub fn switch_mode(&mut self, new_mode: Mode) {
83        self.mode = new_mode;
84        self.gullet.switch_mode(new_mode);
85    }
86
87    // ── Main parse entry ────────────────────────────────────────────────
88
89    /// Parse the entire input and return the AST.
90    pub fn parse(&mut self) -> ParseResult<Vec<ParseNode>> {
91        self.gullet.begin_group();
92
93        let result = self.parse_expression(false, None);
94
95        match result {
96            Ok(parse) => {
97                self.expect("EOF", true)?;
98                self.gullet.end_group();
99                Ok(parse)
100            }
101            Err(e) => {
102                self.gullet.end_groups();
103                Err(e)
104            }
105        }
106    }
107
108    // ── Expression parsing ──────────────────────────────────────────────
109
110    /// Parse an expression: a list of atoms.
111    pub fn parse_expression(
112        &mut self,
113        break_on_infix: bool,
114        break_on_token_text: Option<&str>,
115    ) -> ParseResult<Vec<ParseNode>> {
116        let mut body = Vec::new();
117
118        loop {
119            if self.mode == Mode::Math {
120                self.consume_spaces()?;
121            }
122
123            let lex = self.fetch()?;
124
125            if END_OF_EXPRESSION.contains(&lex.text.as_str()) {
126                break;
127            }
128            if let Some(break_text) = break_on_token_text {
129                if lex.text == break_text {
130                    break;
131                }
132            }
133            if break_on_infix {
134                if let Some(func) = FUNCTIONS.get(lex.text.as_str()) {
135                    if func.infix {
136                        break;
137                    }
138                }
139            }
140
141            let atom = self.parse_atom(break_on_token_text)?;
142
143            match atom {
144                None => break,
145                Some(node) if node.type_name() == "internal" => continue,
146                Some(node) => body.push(node),
147            }
148        }
149
150        if self.mode == Mode::Text {
151            self.form_ligatures(&mut body);
152        }
153
154        self.handle_infix_nodes(body)
155    }
156
157    /// Rewrite infix operators (e.g. \over → \frac).
158    fn handle_infix_nodes(&mut self, body: Vec<ParseNode>) -> ParseResult<Vec<ParseNode>> {
159        let mut over_index: Option<usize> = None;
160        let mut func_name: Option<String> = None;
161
162        for (i, node) in body.iter().enumerate() {
163            if let ParseNode::Infix { replace_with, .. } = node {
164                if over_index.is_some() {
165                    return Err(ParseError::msg("only one infix operator per group"));
166                }
167                over_index = Some(i);
168                func_name = Some(replace_with.clone());
169            }
170        }
171
172        if let (Some(idx), Some(fname)) = (over_index, func_name) {
173            let numer_body: Vec<ParseNode> = body[..idx].to_vec();
174            let denom_body: Vec<ParseNode> = body[idx + 1..].to_vec();
175
176            let numer = if numer_body.len() == 1 {
177                if let ParseNode::OrdGroup { .. } = &numer_body[0] {
178                    numer_body.into_iter().next().unwrap()
179                } else {
180                    ParseNode::OrdGroup {
181                        mode: self.mode,
182                        body: numer_body,
183                        semisimple: None,
184                        loc: None,
185                    }
186                }
187            } else {
188                ParseNode::OrdGroup {
189                    mode: self.mode,
190                    body: numer_body,
191                    semisimple: None,
192                    loc: None,
193                }
194            };
195
196            let denom = if denom_body.len() == 1 {
197                if let ParseNode::OrdGroup { .. } = &denom_body[0] {
198                    denom_body.into_iter().next().unwrap()
199                } else {
200                    ParseNode::OrdGroup {
201                        mode: self.mode,
202                        body: denom_body,
203                        semisimple: None,
204                        loc: None,
205                    }
206                }
207            } else {
208                ParseNode::OrdGroup {
209                    mode: self.mode,
210                    body: denom_body,
211                    semisimple: None,
212                    loc: None,
213                }
214            };
215
216            let node = if fname == "\\\\abovefrac" {
217                // \above passes the infix node (with bar size) as the middle argument
218                let infix_node = body[idx].clone();
219                self.call_function(&fname, vec![numer, infix_node, denom], vec![], None, None)?
220            } else {
221                self.call_function(&fname, vec![numer, denom], vec![], None, None)?
222            };
223            Ok(vec![node])
224        } else {
225            Ok(body)
226        }
227    }
228
229    /// Form ligatures in text mode (e.g. -- → –, --- → —).
230    fn form_ligatures(&self, group: &mut Vec<ParseNode>) {
231        let mut i = 0;
232        while i + 1 < group.len() {
233            let a_text = group[i].symbol_text().map(|s| s.to_string());
234            let b_text = group[i + 1].symbol_text().map(|s| s.to_string());
235
236            if let (Some(a), Some(b)) = (a_text, b_text) {
237                if group[i].type_name() == "textord" && group[i + 1].type_name() == "textord" {
238                    if a == "-" && b == "-" {
239                        if i + 2 < group.len() {
240                            if let Some(c) = group[i + 2].symbol_text() {
241                                if c == "-" && group[i + 2].type_name() == "textord" {
242                                    group[i] = ParseNode::TextOrd {
243                                        mode: Mode::Text,
244                                        text: "---".to_string(),
245                                        loc: None,
246                                    };
247                                    group.remove(i + 2);
248                                    group.remove(i + 1);
249                                    continue;
250                                }
251                            }
252                        }
253                        group[i] = ParseNode::TextOrd {
254                            mode: Mode::Text,
255                            text: "--".to_string(),
256                            loc: None,
257                        };
258                        group.remove(i + 1);
259                        continue;
260                    }
261                    if (a == "'" || a == "`") && b == a {
262                        group[i] = ParseNode::TextOrd {
263                            mode: Mode::Text,
264                            text: format!("{}{}", a, a),
265                            loc: None,
266                        };
267                        group.remove(i + 1);
268                        continue;
269                    }
270                }
271            }
272            i += 1;
273        }
274    }
275
276    // ── Atom parsing ────────────────────────────────────────────────────
277
    /// Parse a single atom with optional super/subscripts.
    ///
    /// The base is whatever `parse_group("atom", ..)` yields (possibly nothing).
    /// An `"internal"` base is returned immediately, and in text mode the base
    /// is returned without looking for scripts. Otherwise any run of `\limits`
    /// / `\nolimits`, `^`, `_`, primes (`'`) and Unicode super/subscript
    /// characters following the base is absorbed.
    ///
    /// Returns `Ok(None)` when there is neither a base nor any script; returns a
    /// `ParseNode::SupSub` (whose base may be absent) when at least one script
    /// was found; otherwise returns the base unchanged.
    ///
    /// # Errors
    /// Fails with "Double superscript" / "Double subscript" when a second
    /// script of the same kind is encountered, and propagates errors from the
    /// group and script parsers.
    pub fn parse_atom(
        &mut self,
        break_on_token_text: Option<&str>,
    ) -> ParseResult<Option<ParseNode>> {
        let mut base = self.parse_group("atom", break_on_token_text)?;

        // Internal nodes never take scripts; hand them straight back.
        if let Some(ref b) = base {
            if b.type_name() == "internal" {
                return Ok(base);
            }
        }

        // Scripts are only recognised in math mode.
        if self.mode == Mode::Text {
            return Ok(base);
        }

        let mut superscript: Option<ParseNode> = None;
        let mut subscript: Option<ParseNode> = None;

        loop {
            self.consume_spaces()?;
            let lex = self.fetch()?;

            if lex.text == "\\limits" || lex.text == "\\nolimits" {
                let is_limits = lex.text == "\\limits";
                self.consume();
                // Only `Op` and `OperatorName` bases carry a `limits` flag; for
                // any other base the token is consumed and has no effect.
                // NOTE(review): KaTeX reports an error for limit controls that
                // do not follow an operator — confirm the leniency is intended.
                if let Some(
                    ParseNode::Op { limits, .. }
                    | ParseNode::OperatorName { limits, .. },
                ) = base.as_mut()
                {
                    *limits = is_limits;
                }
            } else if lex.text == "^" {
                if superscript.is_some() {
                    return Err(ParseError::new("Double superscript", Some(&lex)));
                }
                superscript = Some(self.handle_sup_subscript("superscript")?);
            } else if lex.text == "_" {
                if subscript.is_some() {
                    return Err(ParseError::new("Double subscript", Some(&lex)));
                }
                subscript = Some(self.handle_sup_subscript("subscript")?);
            } else if lex.text == "'" {
                if superscript.is_some() {
                    return Err(ParseError::new("Double superscript", Some(&lex)));
                }
                // Every `'` becomes one `\prime` text-ord in the superscript group.
                let prime = ParseNode::TextOrd {
                    mode: self.mode,
                    text: "\\prime".to_string(),
                    loc: None,
                };
                let mut primes = vec![prime.clone()];
                self.consume();
                while self.fetch()?.text == "'" {
                    primes.push(prime.clone());
                    self.consume();
                }
                // A `^` directly after the primes joins the same superscript
                // group instead of counting as a second superscript.
                if self.fetch()?.text == "^" {
                    primes.push(self.handle_sup_subscript("superscript")?);
                }
                superscript = Some(ParseNode::OrdGroup {
                    mode: self.mode,
                    body: primes,
                    semisimple: None,
                    loc: None,
                });
            } else if let Some((mapped, is_sub)) = lex
                .text
                .chars()
                .next()
                .and_then(crate::unicode_sup_sub::unicode_sub_sup)
            {
                // The token's first character is a Unicode super/subscript
                // character; `mapped` is its replacement and `is_sub` its kind.
                if is_sub && subscript.is_some() {
                    return Err(ParseError::new("Double subscript", Some(&lex)));
                }
                if !is_sub && superscript.is_some() {
                    return Err(ParseError::new("Double superscript", Some(&lex)));
                }
                // Collect consecutive Unicode sup/sub chars of the same kind.
                // Each later character is inserted at the front, so the vector
                // ends up in reverse source order — presumably because
                // `subparse` consumes its tokens stack-wise; confirm there.
                let mut subsup_tokens = vec![Token::new(mapped, 0, 0)];
                self.consume();
                loop {
                    let tok = self.fetch()?;
                    match tok
                        .text
                        .chars()
                        .next()
                        .and_then(crate::unicode_sup_sub::unicode_sub_sup)
                    {
                        Some((m, sub)) if sub == is_sub => {
                            subsup_tokens.insert(0, Token::new(m, 0, 0));
                            self.consume();
                        }
                        _ => break,
                    }
                }
                let body = self.subparse(subsup_tokens)?;
                // The resulting script group is always tagged as math mode.
                let group = ParseNode::OrdGroup {
                    mode: Mode::Math,
                    body,
                    semisimple: None,
                    loc: None,
                };
                if is_sub {
                    subscript = Some(group);
                } else {
                    superscript = Some(group);
                }
            } else {
                // Anything else ends the run of scripts.
                break;
            }
        }

        if superscript.is_some() || subscript.is_some() {
            // The base may be absent here (e.g. a script with nothing before it).
            Ok(Some(ParseNode::SupSub {
                mode: self.mode,
                base: base.map(Box::new),
                sup: superscript.map(Box::new),
                sub: subscript.map(Box::new),
                loc: None,
            }))
        } else {
            Ok(base)
        }
    }
405
406    /// Handle a subscript or superscript.
407    fn handle_sup_subscript(&mut self, name: &str) -> ParseResult<ParseNode> {
408        let symbol_token = self.fetch()?;
409        self.consume();
410        self.consume_spaces()?;
411
412        let group = self.parse_group(name, None)?;
413        match group {
414            Some(g) if g.type_name() != "internal" => Ok(g),
415            Some(_) => {
416                // Skip internal nodes, try again
417                let g2 = self.parse_group(name, None)?;
418                g2.ok_or_else(|| {
419                    ParseError::new(
420                        format!("Expected group after '{}'", symbol_token.text),
421                        Some(&symbol_token),
422                    )
423                })
424            }
425            None => Err(ParseError::new(
426                format!("Expected group after '{}'", symbol_token.text),
427                Some(&symbol_token),
428            )),
429        }
430    }
431
432    // ── Group parsing ───────────────────────────────────────────────────
433
434    /// Parse a group: braced expression, function call, or single symbol.
435    pub fn parse_group(
436        &mut self,
437        name: &str,
438        break_on_token_text: Option<&str>,
439    ) -> ParseResult<Option<ParseNode>> {
440        let first_token = self.fetch()?;
441        let text = first_token.text.clone();
442
443        if text == "{" || text == "\\begingroup" {
444            self.consume();
445            let group_end = if text == "{" { "}" } else { "\\endgroup" };
446
447            self.gullet.begin_group();
448            let expression = self.parse_expression(false, Some(group_end))?;
449            let last_token = self.fetch()?;
450            self.expect(group_end, true)?;
451            self.gullet.end_group();
452
453            let loc = Some(SourceLocation::range(&first_token.loc, &last_token.loc));
454            let semisimple = if text == "\\begingroup" {
455                Some(true)
456            } else {
457                None
458            };
459
460            Ok(Some(ParseNode::OrdGroup {
461                mode: self.mode,
462                body: expression,
463                semisimple,
464                loc,
465            }))
466        } else {
467            let result = self
468                .parse_function(break_on_token_text, Some(name))?
469                .or_else(|| self.parse_symbol_inner().ok().flatten());
470
471            if result.is_none()
472                && text.starts_with('\\')
473                && !IMPLICIT_COMMANDS.contains(&text.as_str())
474            {
475                return Err(ParseError::new(
476                    format!("Undefined control sequence: {}", text),
477                    Some(&first_token),
478                ));
479            }
480
481            Ok(result)
482        }
483    }
484
485    // ── Function parsing ────────────────────────────────────────────────
486
487    /// Try to parse a function call. Returns None if not a function.
488    pub fn parse_function(
489        &mut self,
490        break_on_token_text: Option<&str>,
491        name: Option<&str>,
492    ) -> ParseResult<Option<ParseNode>> {
493        let token = self.fetch()?;
494        let func = token.text.clone();
495
496        let func_data = match FUNCTIONS.get(func.as_str()) {
497            Some(f) => f,
498            None => return Ok(None),
499        };
500
501        self.consume();
502
503        if let Some(n) = name {
504            if n != "atom" && !func_data.allowed_in_argument {
505                return Err(ParseError::new(
506                    format!("Got function '{}' with no arguments as {}", func, n),
507                    Some(&token),
508                ));
509            }
510        }
511
512        functions::check_mode_compatibility(func_data, self.mode, &func, Some(&token))?;
513
514        // `\hspace*{len}` — `*` is a separate token (not part of the control word); consume it here.
515        // Must use gullet peek/pop only: `parser.fetch()` without `consume()` advances the lexer and
516        // leaves `{` only in `next_token`, so `parse_size_group`'s `gullet.future()` would miss the brace.
517        if func == "\\hspace" {
518            self.gullet.consume_spaces();
519            if self.gullet.future().text == "*" {
520                self.gullet.pop_token();
521            }
522        }
523
524        let (args, opt_args) = self.parse_arguments(&func, func_data)?;
525
526        self.call_function(
527            &func,
528            args,
529            opt_args,
530            Some(token),
531            break_on_token_text.map(|s| s.to_string()).as_deref(),
532        )
533        .map(Some)
534    }
535
536    /// Call a function handler.
537    pub fn call_function(
538        &mut self,
539        name: &str,
540        args: Vec<ParseNode>,
541        opt_args: Vec<Option<ParseNode>>,
542        token: Option<Token>,
543        break_on_token_text: Option<&str>,
544    ) -> ParseResult<ParseNode> {
545        let func = FUNCTIONS.get(name).ok_or_else(|| {
546            ParseError::msg(format!("No function handler for {}", name))
547        })?;
548
549        let mut ctx = FunctionContext {
550            func_name: name.to_string(),
551            parser: self,
552            token: token.clone(),
553            break_on_token_text: break_on_token_text.map(|s| s.to_string()),
554        };
555
556        (func.handler)(&mut ctx, args, opt_args)
557    }
558
559    /// Parse the arguments for a function.
560    pub fn parse_arguments(
561        &mut self,
562        func: &str,
563        func_data: &functions::FunctionSpec,
564    ) -> ParseResult<(Vec<ParseNode>, Vec<Option<ParseNode>>)> {
565        let total_args = func_data.num_args + func_data.num_optional_args;
566        if total_args == 0 {
567            return Ok((Vec::new(), Vec::new()));
568        }
569
570        let mut args = Vec::new();
571        let mut opt_args = Vec::new();
572
573        for i in 0..total_args {
574            let arg_type = func_data
575                .arg_types
576                .as_ref()
577                .and_then(|types| types.get(i).copied());
578            let is_optional = i < func_data.num_optional_args;
579
580            let effective_type = if (func_data.primitive && arg_type.is_none())
581                || (func_data.node_type == "sqrt" && i == 1
582                    && opt_args.first().is_some_and(|o: &Option<ParseNode>| o.is_none()))
583            {
584                Some(ArgType::Primitive)
585            } else {
586                arg_type
587            };
588
589            let arg = self.parse_group_of_type(
590                &format!("argument to '{}'", func),
591                effective_type,
592                is_optional,
593            )?;
594
595            if is_optional {
596                opt_args.push(arg);
597            } else if let Some(a) = arg {
598                args.push(a);
599            } else {
600                return Err(ParseError::msg("Null argument, please report this as a bug"));
601            }
602        }
603
604        Ok((args, opt_args))
605    }
606
607    /// Parse a group with a specific type.
608    fn parse_group_of_type(
609        &mut self,
610        name: &str,
611        arg_type: Option<ArgType>,
612        optional: bool,
613    ) -> ParseResult<Option<ParseNode>> {
614        match arg_type {
615            Some(ArgType::Color) => self.parse_color_group(optional),
616            Some(ArgType::Size) => self.parse_size_group(optional),
617            Some(ArgType::Primitive) => {
618                if optional {
619                    return Err(ParseError::msg("A primitive argument cannot be optional"));
620                }
621                let group = self.parse_group(name, None)?;
622                match group {
623                    Some(g) => Ok(Some(g)),
624                    None => Err(ParseError::new(
625                        format!("Expected group as {}", name),
626                        None,
627                    )),
628                }
629            }
630            Some(ArgType::Math) | Some(ArgType::Text) => {
631                let mode = match arg_type {
632                    Some(ArgType::Math) => Some(Mode::Math),
633                    Some(ArgType::Text) => Some(Mode::Text),
634                    _ => None,
635                };
636                self.parse_argument_group(optional, mode)
637            }
638            Some(ArgType::HBox) => {
639                let group = self.parse_argument_group(optional, Some(Mode::Text))?;
640                match group {
641                    Some(g) => Ok(Some(ParseNode::Styling {
642                        mode: g.mode(),
643                        style: crate::parse_node::StyleStr::Text,
644                        body: vec![g],
645                        loc: None,
646                    })),
647                    None => Ok(None),
648                }
649            }
650            Some(ArgType::Raw) => {
651                let token = self.parse_string_group("raw", optional)?;
652                match token {
653                    Some(t) => Ok(Some(ParseNode::Raw {
654                        mode: Mode::Text,
655                        string: t.text,
656                        loc: None,
657                    })),
658                    None => Ok(None),
659                }
660            }
661            Some(ArgType::Url) => self.parse_url_group(optional),
662            None | Some(ArgType::Original) => self.parse_argument_group(optional, None),
663        }
664    }
665
666    /// Parse a color group.
667    fn parse_color_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
668        let res = self.parse_string_group("color", optional)?;
669        match res {
670            None => Ok(None),
671            Some(token) => {
672                let text = token.text.trim().to_string();
673                let re = regex_lite::Regex::new(
674                    r"^(#[a-fA-F0-9]{3,4}|#[a-fA-F0-9]{6}|#[a-fA-F0-9]{8}|[a-fA-F0-9]{6}|[a-zA-Z]+|\d+(\.\d+)?(,\d+(\.\d+)?)*)$",
675                )
676                .unwrap();
677
678                if !re.is_match(&text) {
679                    return Err(ParseError::new(
680                        format!("Invalid color: '{}'", text),
681                        Some(&token),
682                    ));
683                }
684                let mut color = text;
685                if regex_lite::Regex::new(r"^[0-9a-fA-F]{6}$")
686                    .unwrap()
687                    .is_match(&color)
688                {
689                    color = format!("#{}", color);
690                }
691
692                Ok(Some(ParseNode::ColorToken {
693                    mode: self.mode,
694                    color,
695                    loc: None,
696                }))
697            }
698        }
699    }
700
701    /// Parse a size group (e.g., "3pt", "1em").
702    pub fn parse_size_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
703        let mut is_blank = false;
704
705        self.gullet.consume_spaces();
706        let res = if !optional && self.gullet.future().text != "{" {
707            Some(self.parse_regex_group(
708                &regex_lite::Regex::new(r"^[-+]? *(?:$|\d+|\d+\.\d*|\.\d*) *[a-z]{0,2} *$")
709                    .unwrap(),
710                "size",
711            )?)
712        } else {
713            self.parse_string_group("size", optional)?
714        };
715
716        let res = match res {
717            Some(r) => r,
718            None => return Ok(None),
719        };
720
721        let mut text = res.text.clone();
722        if !optional && text.is_empty() {
723            text = "0pt".to_string();
724            is_blank = true;
725        }
726
727        let size_re =
728            regex_lite::Regex::new(r"([-+]?) *(\d+(?:\.\d*)?|\.\d+) *([a-z]{2})").unwrap();
729        let m = size_re.captures(&text).ok_or_else(|| {
730            ParseError::new(format!("Invalid size: '{}'", text), Some(&res))
731        })?;
732
733        let sign = m.get(1).map_or("", |m| m.as_str());
734        let magnitude = m.get(2).map_or("", |m| m.as_str());
735        let unit = m.get(3).map_or("", |m| m.as_str());
736
737        let number: f64 = format!("{}{}", sign, magnitude).parse().unwrap_or(0.0);
738
739        if !is_valid_unit(unit) {
740            return Err(ParseError::new(
741                format!("Invalid unit: '{}'", unit),
742                Some(&res),
743            ));
744        }
745
746        Ok(Some(ParseNode::Size {
747            mode: self.mode,
748            value: crate::parse_node::Measurement {
749                number,
750                unit: unit.to_string(),
751            },
752            is_blank,
753            loc: None,
754        }))
755    }
756
757    /// Parse a URL group.
758    /// Temporarily disables `%` as comment character to allow `%20` etc. in URLs.
759    fn parse_url_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
760        self.gullet.lexer.set_catcode('%', 13);
761        self.gullet.lexer.set_catcode('~', 12);
762        let res = self.parse_string_group("url", optional);
763        self.gullet.lexer.set_catcode('%', 14);
764        self.gullet.lexer.set_catcode('~', 13);
765        let res = res?;
766        match res {
767            None => Ok(None),
768            Some(token) => {
769                let url = token.text;
770                Ok(Some(ParseNode::Url {
771                    mode: self.mode,
772                    url,
773                    loc: None,
774                }))
775            }
776        }
777    }
778
779    /// Parse a string group (brace-enclosed string).
780    fn parse_string_group(
781        &mut self,
782        _mode_name: &str,
783        optional: bool,
784    ) -> ParseResult<Option<Token>> {
785        let arg_token = self.gullet.scan_argument(optional)?;
786        let arg_token = match arg_token {
787            Some(t) => t,
788            None => return Ok(None),
789        };
790
791        let mut s = String::new();
792        loop {
793            let next = self.fetch()?;
794            if next.text == "EOF" {
795                break;
796            }
797            s.push_str(&next.text);
798            self.consume();
799        }
800        self.consume(); // consume EOF
801
802        let mut result = arg_token;
803        result.text = s;
804        Ok(Some(result))
805    }
806
807    /// Parse a regex-delimited group.
808    fn parse_regex_group(
809        &mut self,
810        regex: &regex_lite::Regex,
811        mode_name: &str,
812    ) -> ParseResult<Token> {
813        let first_token = self.fetch()?;
814        let mut last_token = first_token.clone();
815        let mut s = String::new();
816
817        loop {
818            let next = self.fetch()?;
819            if next.text == "EOF" {
820                break;
821            }
822            let candidate = format!("{}{}", s, next.text);
823            if regex.is_match(&candidate) {
824                last_token = next;
825                s = candidate;
826                self.consume();
827            } else {
828                break;
829            }
830        }
831
832        if s.is_empty() {
833            return Err(ParseError::new(
834                format!("Invalid {}: '{}'", mode_name, first_token.text),
835                Some(&first_token),
836            ));
837        }
838
839        Ok(first_token.range(&last_token, s))
840    }
841
842    /// Parse an argument group (with optional mode switch).
843    pub fn parse_argument_group(
844        &mut self,
845        optional: bool,
846        mode: Option<Mode>,
847    ) -> ParseResult<Option<ParseNode>> {
848        let arg_token = self.gullet.scan_argument(optional)?;
849        let arg_token = match arg_token {
850            Some(t) => t,
851            None => return Ok(None),
852        };
853
854        let outer_mode = self.mode;
855        if let Some(m) = mode {
856            self.switch_mode(m);
857        }
858
859        self.gullet.begin_group();
860        let expression = self.parse_expression(false, Some("EOF"))?;
861        self.expect("EOF", true)?;
862        self.gullet.end_group();
863
864        let result = ParseNode::OrdGroup {
865            mode: self.mode,
866            loc: Some(arg_token.loc.clone()),
867            body: expression,
868            semisimple: None,
869        };
870
871        if mode.is_some() {
872            self.switch_mode(outer_mode);
873        }
874
875        Ok(Some(result))
876    }
877
878    // ── Symbol parsing ──────────────────────────────────────────────────
879
880    /// Parse a single symbol (internal version that returns Result).
881    fn parse_symbol_inner(&mut self) -> ParseResult<Option<ParseNode>> {
882        let nucleus = self.fetch()?;
883        let text = nucleus.text.clone();
884
885        if let Some(stripped) = text.strip_prefix("\\verb") {
886            self.consume();
887            let arg = stripped.to_string();
888            let star = arg.starts_with('*');
889            let arg = if star { &arg[1..] } else { &arg };
890
891            if arg.len() < 2 {
892                return Err(ParseError::new("\\verb assertion failed", Some(&nucleus)));
893            }
894            let body = arg[1..arg.len() - 1].to_string();
895            return Ok(Some(ParseNode::Verb {
896                mode: Mode::Text,
897                body,
898                star,
899                loc: Some(nucleus.loc.clone()),
900            }));
901        }
902
903        let font_mode = match self.mode {
904            Mode::Math => ratex_font::symbols::Mode::Math,
905            Mode::Text => ratex_font::symbols::Mode::Text,
906        };
907
908        // ^ and _ are handled by parse_atom for sup/sub, not as symbol nodes
909        if text == "^" || text == "_" {
910            return Ok(None);
911        }
912
913        // Bare backslash (incomplete control sequence) → not a valid symbol
914        if text == "\\" {
915            return Ok(None);
916        }
917
918        if let Some(sym_info) = ratex_font::symbols::get_symbol(&text, font_mode) {
919            let loc = Some(SourceLocation::range(&nucleus.loc, &nucleus.loc));
920            let group = sym_info.group;
921
922            let node = if group.is_atom() {
923                let family = match group {
924                    ratex_font::symbols::Group::Bin => AtomFamily::Bin,
925                    ratex_font::symbols::Group::Close => AtomFamily::Close,
926                    ratex_font::symbols::Group::Inner => AtomFamily::Inner,
927                    ratex_font::symbols::Group::Open => AtomFamily::Open,
928                    ratex_font::symbols::Group::Punct => AtomFamily::Punct,
929                    ratex_font::symbols::Group::Rel => AtomFamily::Rel,
930                    _ => unreachable!(),
931                };
932                ParseNode::Atom {
933                    mode: self.mode,
934                    family,
935                    text: text.clone(),
936                    loc,
937                }
938            } else {
939                match group {
940                    ratex_font::symbols::Group::MathOrd => ParseNode::MathOrd {
941                        mode: self.mode,
942                        text: text.clone(),
943                        loc,
944                    },
945                    ratex_font::symbols::Group::TextOrd => ParseNode::TextOrd {
946                        mode: self.mode,
947                        text: text.clone(),
948                        loc,
949                    },
950                    ratex_font::symbols::Group::OpToken => ParseNode::OpToken {
951                        mode: self.mode,
952                        text: text.clone(),
953                        loc,
954                    },
955                    ratex_font::symbols::Group::AccentToken => ParseNode::AccentToken {
956                        mode: self.mode,
957                        text: text.clone(),
958                        loc,
959                    },
960                    ratex_font::symbols::Group::Spacing => ParseNode::SpacingNode {
961                        mode: self.mode,
962                        text: text.clone(),
963                        loc,
964                    },
965                    _ => ParseNode::MathOrd {
966                        mode: self.mode,
967                        text: text.clone(),
968                        loc,
969                    },
970                }
971            };
972
973            self.consume();
974            return Ok(Some(node));
975        }
976
977        // Unicode accented characters → decompose into accent nodes
978        // Handles both precomposed (á U+00E1) and combining forms (a + U+0301)
979        if let Some(node) = self.try_parse_unicode_accent(&text, &nucleus)? {
980            self.consume();
981            return Ok(Some(node));
982        }
983
984        // Non-ASCII characters without accent decomposition → treat as textord
985        // KaTeX always uses mode="text" for these, regardless of current mode
986        let first_char = text.chars().next();
987        if let Some(ch) = first_char {
988            if ch as u32 >= 0x80 {
989                let node = ParseNode::TextOrd {
990                    mode: Mode::Text,
991                    text: text.clone(),
992                    loc: Some(SourceLocation::range(&nucleus.loc, &nucleus.loc)),
993                };
994                self.consume();
995                return Ok(Some(node));
996            }
997        }
998
999        Ok(None)
1000    }
1001
1002    /// Try to decompose a Unicode accented character into accent nodes.
1003    /// Returns None if no decomposition is available.
1004    /// Only decomposes Latin-script characters, matching KaTeX behavior.
1005    fn try_parse_unicode_accent(
1006        &self,
1007        text: &str,
1008        nucleus: &Token,
1009    ) -> ParseResult<Option<ParseNode>> {
1010        let nfd: String = text.nfd().collect();
1011        let chars: Vec<char> = nfd.chars().collect();
1012
1013        if chars.len() < 2 {
1014            return Ok(None);
1015        }
1016
1017        // Build from the base up through each combining mark
1018        let mut split_idx = chars.len() - 1;
1019        while split_idx > 0 && is_supported_combining_accent(chars[split_idx]) {
1020            split_idx -= 1;
1021        }
1022
1023        // Verify ALL trailing chars are supported combining accents
1024        if split_idx == chars.len() - 1 {
1025            return Ok(None);
1026        }
1027
1028        // Only decompose Latin-script base characters
1029        let base_char = chars[0];
1030        if !is_latin_base_char(base_char) {
1031            return Ok(None);
1032        }
1033
1034        let loc = Some(SourceLocation::range(&nucleus.loc, &nucleus.loc));
1035
1036        // Base: everything before the combining marks
1037        let mut base_str: String = chars[..split_idx + 1].iter().collect();
1038
1039        // Accented i→ı and j→ȷ (dotless variants), matching KaTeX behavior
1040        if base_str.len() == 1 {
1041            match base_str.as_str() {
1042                "i" => base_str = "\u{0131}".to_string(), // ı
1043                "j" => base_str = "\u{0237}".to_string(), // ȷ
1044                _ => {}
1045            }
1046        }
1047
1048        let font_mode = match self.mode {
1049            Mode::Math => ratex_font::symbols::Mode::Math,
1050            Mode::Text => ratex_font::symbols::Mode::Text,
1051        };
1052
1053        let mut node = if base_str.chars().count() == 1 {
1054            let ch = base_str.chars().next().unwrap();
1055            if let Some(sym) = ratex_font::symbols::get_symbol(&base_str, font_mode) {
1056                match sym.group {
1057                    ratex_font::symbols::Group::TextOrd => ParseNode::TextOrd {
1058                        mode: self.mode,
1059                        text: base_str.clone(),
1060                        loc: loc.clone(),
1061                    },
1062                    _ => ParseNode::MathOrd {
1063                        mode: self.mode,
1064                        text: base_str.clone(),
1065                        loc: loc.clone(),
1066                    },
1067                }
1068            } else if (ch as u32) >= 0x80 {
1069                // Non-ASCII base chars always text mode (KaTeX compat)
1070                ParseNode::TextOrd {
1071                    mode: Mode::Text,
1072                    text: base_str.clone(),
1073                    loc: loc.clone(),
1074                }
1075            } else {
1076                ParseNode::MathOrd {
1077                    mode: self.mode,
1078                    text: base_str.clone(),
1079                    loc: loc.clone(),
1080                }
1081            }
1082        } else {
1083            return self.try_parse_unicode_accent(&base_str, nucleus).map(|opt| {
1084                opt.or_else(|| {
1085                    Some(ParseNode::TextOrd {
1086                        mode: Mode::Text,
1087                        text: base_str.clone(),
1088                        loc: loc.clone(),
1089                    })
1090                })
1091            });
1092        };
1093
1094        // Wrap in accent nodes from innermost to outermost
1095        for &combining in &chars[split_idx + 1..] {
1096            let label = combining_to_accent_label(combining, self.mode);
1097            node = ParseNode::Accent {
1098                mode: self.mode,
1099                label,
1100                is_stretchy: Some(false),
1101                is_shifty: Some(true),
1102                base: Box::new(node),
1103                loc: loc.clone(),
1104            };
1105        }
1106
1107        Ok(Some(node))
1108    }
1109
1110    /// Parse a sub-expression from the given tokens.
1111    pub fn subparse(&mut self, tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
1112        let old_token = self.next_token.take();
1113
1114        self.gullet
1115            .push_token(Token::new("}", 0, 0));
1116        self.gullet.push_tokens(tokens);
1117        let parse = self.parse_expression(false, None)?;
1118        self.expect("}", true)?;
1119
1120        self.next_token = old_token;
1121        Ok(parse)
1122    }
1123}
1124
/// Report whether `ch` counts as a Latin-script base character.
///
/// Accepts the ASCII letters plus a fixed set of additional Latin letters:
/// the dotless ı and ȷ, and Æ Ð Ø Þ ß æ ð ø þ.
fn is_latin_base_char(ch: char) -> bool {
    const EXTRA_LATIN_BASES: [char; 11] = [
        '\u{0131}', // ı (dotless i)
        '\u{0237}', // ȷ (dotless j)
        '\u{00C6}', // Æ
        '\u{00D0}', // Ð
        '\u{00D8}', // Ø
        '\u{00DE}', // Þ
        '\u{00DF}', // ß
        '\u{00E6}', // æ
        '\u{00F0}', // ð
        '\u{00F8}', // ø
        '\u{00FE}', // þ
    ];

    ch.is_ascii_alphabetic() || EXTRA_LATIN_BASES.iter().any(|&extra| extra == ch)
}
1141
/// Report whether `ch` is one of the combining marks that can be mapped to
/// an accent command.
///
/// The supported marks are U+0300–U+0304, U+0306–U+0308, U+030A–U+030C and
/// U+0327.
fn is_supported_combining_accent(ch: char) -> bool {
    match ch {
        '\u{0300}'..='\u{0304}' => true,
        '\u{0306}'..='\u{0308}' => true,
        '\u{030A}'..='\u{030C}' => true,
        '\u{0327}' => true,
        _ => false,
    }
}
1150
1151fn combining_to_accent_label(ch: char, mode: Mode) -> String {
1152    match mode {
1153        Mode::Math => match ch {
1154            '\u{0300}' => "\\grave".to_string(),
1155            '\u{0301}' => "\\acute".to_string(),
1156            '\u{0302}' => "\\hat".to_string(),
1157            '\u{0303}' => "\\tilde".to_string(),
1158            '\u{0304}' => "\\bar".to_string(),
1159            '\u{0306}' => "\\breve".to_string(),
1160            '\u{0307}' => "\\dot".to_string(),
1161            '\u{0308}' => "\\ddot".to_string(),
1162            '\u{030A}' => "\\mathring".to_string(),
1163            '\u{030B}' => "\\H".to_string(),
1164            '\u{030C}' => "\\check".to_string(),
1165            '\u{0327}' => "\\c".to_string(),
1166            _ => format!("\\char\"{:X}", ch as u32),
1167        },
1168        Mode::Text => match ch {
1169            '\u{0300}' => "\\`".to_string(),
1170            '\u{0301}' => "\\'".to_string(),
1171            '\u{0302}' => "\\^".to_string(),
1172            '\u{0303}' => "\\~".to_string(),
1173            '\u{0304}' => "\\=".to_string(),
1174            '\u{0306}' => "\\u".to_string(),
1175            '\u{0307}' => "\\.".to_string(),
1176            '\u{0308}' => "\\\"".to_string(),
1177            '\u{030A}' => "\\r".to_string(),
1178            '\u{030B}' => "\\H".to_string(),
1179            '\u{030C}' => "\\v".to_string(),
1180            '\u{0327}' => "\\c".to_string(),
1181            _ => format!("\\char\"{:X}", ch as u32),
1182        },
1183    }
1184}
1185
/// Report whether `unit` is one of the recognised length-unit names.
///
/// The comparison is exact: case and surrounding whitespace matter.
fn is_valid_unit(unit: &str) -> bool {
    const KNOWN_UNITS: [&str; 15] = [
        "pt", "mm", "cm", "in", "bp", "pc", "dd", "cc", "nd", "nc", "sp", "px", "ex", "em", "mu",
    ];

    KNOWN_UNITS.iter().any(|&known| known == unit)
}
1193
/// Strip one pair of outer TeX math delimiters (`$$…$$` or `$…$`) from `input`.
///
/// The parser already runs in math mode, and a leading `$` would otherwise hit the `$` / `\\(`
/// "switch to math" handler, which is disallowed in math mode (see `functions::math`).
///
/// The input is trimmed first. If it is wrapped in `$$ … $$` the inside is returned, otherwise
/// if it is wrapped in `$ … $` the inside is returned; in both cases the inside is trimmed again.
/// Input without a matching pair is returned trimmed but otherwise unchanged.
fn strip_outer_math_delimiters(input: &str) -> &str {
    let trimmed = input.trim();

    // Display-math delimiters take precedence over inline ones.
    let display_inner = trimmed
        .strip_prefix("$$")
        .and_then(|rest| rest.strip_suffix("$$"));
    if let Some(inner) = display_inner {
        return inner.trim();
    }

    let inline_inner = trimmed
        .strip_prefix('$')
        .and_then(|rest| rest.strip_suffix('$'));
    match inline_inner {
        Some(inner) => inner.trim(),
        None => trimmed,
    }
}
1207
1208/// Convenience function: parse a LaTeX string and return the AST.
1209pub fn parse(input: &str) -> ParseResult<Vec<ParseNode>> {
1210    Parser::new(strip_outer_math_delimiters(input)).parse()
1211}
1212
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, require success and exactly one top-level node, and return that node.
    fn parse_single(src: &str) -> ParseNode {
        let mut nodes = parse(src).unwrap();
        assert_eq!(nodes.len(), 1);
        nodes.remove(0)
    }

    #[test]
    fn test_parse_single_char() {
        assert_eq!(parse_single("x").type_name(), "mathord");
    }

    #[test]
    fn test_parse_strips_outer_dollar_inline_math() {
        // Wrapping an expression in `$…$` must not change the shape of the result.
        let inner = r"C_p[\ce{H2O(l)}] = \pu{75.3 J // mol K}";
        let wrapped = format!("${inner}$");
        let from_wrapped = parse(&wrapped).expect("wrapped");
        let from_inner = parse(inner).expect("inner");
        assert_eq!(from_wrapped.len(), from_inner.len());
        for (w, i) in from_wrapped.iter().zip(from_inner.iter()) {
            assert_eq!(w.type_name(), i.type_name());
        }
    }

    #[test]
    fn test_parse_addition() {
        let nodes = parse("a+b").unwrap();
        assert_eq!(nodes.len(), 3);
        assert_eq!(nodes[0].type_name(), "mathord"); // a
        assert_eq!(nodes[1].type_name(), "atom"); // +
        assert_eq!(nodes[2].type_name(), "mathord"); // b
    }

    #[test]
    fn test_parse_superscript() {
        assert_eq!(parse_single("x^2").type_name(), "supsub");
    }

    #[test]
    fn test_parse_subscript() {
        assert_eq!(parse_single("a_i").type_name(), "supsub");
    }

    #[test]
    fn test_parse_supsub() {
        let node = parse_single("x^2_i");
        assert_eq!(node.type_name(), "supsub");
        match &node {
            ParseNode::SupSub { sup, sub, .. } => {
                assert!(sup.is_some());
                assert!(sub.is_some());
            }
            _ => panic!("Expected SupSub"),
        }
    }

    #[test]
    fn test_parse_group() {
        assert_eq!(parse_single("{a+b}").type_name(), "ordgroup");
    }

    #[test]
    fn test_parse_frac() {
        assert_eq!(parse_single("\\frac{a}{b}").type_name(), "genfrac");
    }

    #[test]
    fn test_parse_sqrt() {
        assert_eq!(parse_single("\\sqrt{x}").type_name(), "sqrt");
    }

    #[test]
    fn test_parse_sqrt_optional() {
        match &parse_single("\\sqrt[3]{x}") {
            ParseNode::Sqrt { index, .. } => assert!(index.is_some()),
            _ => panic!("Expected Sqrt"),
        }
    }

    #[test]
    fn test_parse_nested() {
        assert_eq!(
            parse_single("\\frac{\\sqrt{a^2+b^2}}{c}").type_name(),
            "genfrac"
        );
    }

    #[test]
    fn test_parse_empty() {
        let nodes = parse("").unwrap();
        assert_eq!(nodes.len(), 0);
    }

    #[test]
    fn test_parse_double_superscript_error() {
        assert!(parse("x^2^3").is_err());
    }

    #[test]
    fn test_parse_unclosed_brace_error() {
        assert!(parse("{x").is_err());
    }

    #[test]
    fn test_parse_json_output() {
        let nodes = parse("x^2").unwrap();
        let json = serde_json::to_string_pretty(&nodes).unwrap();
        assert!(json.contains("supsub"));
    }
}