Skip to main content

ratex_parser/
parser.rs

1use ratex_lexer::token::{SourceLocation, Token};
2use unicode_normalization::UnicodeNormalization;
3
4use crate::error::{ParseError, ParseResult};
5use crate::functions::{self, ArgType, FunctionContext, FUNCTIONS};
6use crate::macro_expander::{MacroExpander, IMPLICIT_COMMANDS};
7use crate::parse_node::{AtomFamily, Mode, ParseNode};
8
/// Token texts that end the atom loop in `Parser::parse_expression`.
///
/// The loop stops (without consuming the token) as soon as the lookahead
/// token's text equals one of these entries.
static END_OF_EXPRESSION: &[&str] = &[
    "}",
    "\\endgroup",
    "\\end",
    "\\right",
    "&",
];
11
12/// The LaTeX parser. Converts a token stream into a ParseNode AST.
13///
14/// Follows KaTeX's Parser.ts closely:
15/// - `parse()` → parse full expression
16/// - `parseExpression()` → parse a list of atoms
17/// - `parseAtom()` → parse one atom with optional super/subscripts
18/// - `parseGroup()` → parse a group (braced or single token)
19/// - `parseFunction()` → parse a function call with arguments
20/// - `parseSymbol()` → parse a single symbol
21pub struct Parser<'a> {
22    pub mode: Mode,
23    pub gullet: MacroExpander<'a>,
24    pub leftright_depth: i32,
25    next_token: Option<Token>,
26}
27
28impl<'a> Parser<'a> {
29    pub fn new(input: &'a str) -> Self {
30        Self {
31            mode: Mode::Math,
32            gullet: MacroExpander::new(input, Mode::Math),
33            leftright_depth: 0,
34            next_token: None,
35        }
36    }
37
38    // ── Token management ────────────────────────────────────────────────
39
40    /// Return the current lookahead token (fetching from gullet if needed).
41    pub fn fetch(&mut self) -> ParseResult<Token> {
42        if self.next_token.is_none() {
43            self.next_token = Some(self.gullet.expand_next_token()?);
44        }
45        Ok(self.next_token.clone().unwrap())
46    }
47
48    /// Discard the current lookahead token.
49    pub fn consume(&mut self) {
50        self.next_token = None;
51    }
52
53    /// Expect the next token to have the given text, consuming it.
54    pub fn expect(&mut self, text: &str, do_consume: bool) -> ParseResult<()> {
55        let tok = self.fetch()?;
56        if tok.text != text {
57            return Err(ParseError::new(
58                format!("Expected '{}', got '{}'", text, tok.text),
59                Some(&tok),
60            ));
61        }
62        if do_consume {
63            self.consume();
64        }
65        Ok(())
66    }
67
68    /// Consume spaces in math mode.
69    pub fn consume_spaces(&mut self) -> ParseResult<()> {
70        loop {
71            let tok = self.fetch()?;
72            if tok.text == " " {
73                self.consume();
74            } else {
75                break;
76            }
77        }
78        Ok(())
79    }
80
81    /// Switch between "math" and "text" modes.
82    pub fn switch_mode(&mut self, new_mode: Mode) {
83        self.mode = new_mode;
84        self.gullet.switch_mode(new_mode);
85    }
86
87    // ── Main parse entry ────────────────────────────────────────────────
88
89    /// Parse the entire input and return the AST.
90    pub fn parse(&mut self) -> ParseResult<Vec<ParseNode>> {
91        self.gullet.begin_group();
92
93        let result = self.parse_expression(false, None);
94
95        match result {
96            Ok(parse) => {
97                self.expect("EOF", true)?;
98                self.gullet.end_group();
99                Ok(parse)
100            }
101            Err(e) => {
102                self.gullet.end_groups();
103                Err(e)
104            }
105        }
106    }
107
108    // ── Expression parsing ──────────────────────────────────────────────
109
110    /// Parse an expression: a list of atoms.
111    pub fn parse_expression(
112        &mut self,
113        break_on_infix: bool,
114        break_on_token_text: Option<&str>,
115    ) -> ParseResult<Vec<ParseNode>> {
116        let mut body = Vec::new();
117
118        loop {
119            if self.mode == Mode::Math {
120                self.consume_spaces()?;
121            }
122
123            let lex = self.fetch()?;
124
125            if END_OF_EXPRESSION.contains(&lex.text.as_str()) {
126                break;
127            }
128            if let Some(break_text) = break_on_token_text {
129                if lex.text == break_text {
130                    break;
131                }
132            }
133            if break_on_infix {
134                if let Some(func) = FUNCTIONS.get(lex.text.as_str()) {
135                    if func.infix {
136                        break;
137                    }
138                }
139            }
140
141            let atom = self.parse_atom(break_on_token_text)?;
142
143            match atom {
144                None => break,
145                Some(node) if node.type_name() == "internal" => continue,
146                Some(node) => body.push(node),
147            }
148        }
149
150        if self.mode == Mode::Text {
151            self.form_ligatures(&mut body);
152        }
153
154        self.handle_infix_nodes(body)
155    }
156
157    /// Rewrite infix operators (e.g. \over → \frac).
158    fn handle_infix_nodes(&mut self, body: Vec<ParseNode>) -> ParseResult<Vec<ParseNode>> {
159        let mut over_index: Option<usize> = None;
160        let mut func_name: Option<String> = None;
161
162        for (i, node) in body.iter().enumerate() {
163            if let ParseNode::Infix { replace_with, .. } = node {
164                if over_index.is_some() {
165                    return Err(ParseError::msg("only one infix operator per group"));
166                }
167                over_index = Some(i);
168                func_name = Some(replace_with.clone());
169            }
170        }
171
172        if let (Some(idx), Some(fname)) = (over_index, func_name) {
173            let numer_body: Vec<ParseNode> = body[..idx].to_vec();
174            let denom_body: Vec<ParseNode> = body[idx + 1..].to_vec();
175
176            let numer = if numer_body.len() == 1 {
177                if let ParseNode::OrdGroup { .. } = &numer_body[0] {
178                    numer_body.into_iter().next().unwrap()
179                } else {
180                    ParseNode::OrdGroup {
181                        mode: self.mode,
182                        body: numer_body,
183                        semisimple: None,
184                        loc: None,
185                    }
186                }
187            } else {
188                ParseNode::OrdGroup {
189                    mode: self.mode,
190                    body: numer_body,
191                    semisimple: None,
192                    loc: None,
193                }
194            };
195
196            let denom = if denom_body.len() == 1 {
197                if let ParseNode::OrdGroup { .. } = &denom_body[0] {
198                    denom_body.into_iter().next().unwrap()
199                } else {
200                    ParseNode::OrdGroup {
201                        mode: self.mode,
202                        body: denom_body,
203                        semisimple: None,
204                        loc: None,
205                    }
206                }
207            } else {
208                ParseNode::OrdGroup {
209                    mode: self.mode,
210                    body: denom_body,
211                    semisimple: None,
212                    loc: None,
213                }
214            };
215
216            let node = if fname == "\\\\abovefrac" {
217                // \above passes the infix node (with bar size) as the middle argument
218                let infix_node = body[idx].clone();
219                self.call_function(&fname, vec![numer, infix_node, denom], vec![], None, None)?
220            } else {
221                self.call_function(&fname, vec![numer, denom], vec![], None, None)?
222            };
223            Ok(vec![node])
224        } else {
225            Ok(body)
226        }
227    }
228
229    /// Form ligatures in text mode (e.g. -- → –, --- → —).
230    fn form_ligatures(&self, group: &mut Vec<ParseNode>) {
231        let mut i = 0;
232        while i + 1 < group.len() {
233            let a_text = group[i].symbol_text().map(|s| s.to_string());
234            let b_text = group[i + 1].symbol_text().map(|s| s.to_string());
235
236            if let (Some(a), Some(b)) = (a_text, b_text) {
237                if group[i].type_name() == "textord" && group[i + 1].type_name() == "textord" {
238                    if a == "-" && b == "-" {
239                        if i + 2 < group.len() {
240                            if let Some(c) = group[i + 2].symbol_text() {
241                                if c == "-" && group[i + 2].type_name() == "textord" {
242                                    group[i] = ParseNode::TextOrd {
243                                        mode: Mode::Text,
244                                        text: "---".to_string(),
245                                        loc: None,
246                                    };
247                                    group.remove(i + 2);
248                                    group.remove(i + 1);
249                                    continue;
250                                }
251                            }
252                        }
253                        group[i] = ParseNode::TextOrd {
254                            mode: Mode::Text,
255                            text: "--".to_string(),
256                            loc: None,
257                        };
258                        group.remove(i + 1);
259                        continue;
260                    }
261                    if (a == "'" || a == "`") && b == a {
262                        group[i] = ParseNode::TextOrd {
263                            mode: Mode::Text,
264                            text: format!("{}{}", a, a),
265                            loc: None,
266                        };
267                        group.remove(i + 1);
268                        continue;
269                    }
270                }
271            }
272            i += 1;
273        }
274    }
275
276    // ── Atom parsing ────────────────────────────────────────────────────
277
278    /// Parse a single atom with optional super/subscripts.
279    pub fn parse_atom(
280        &mut self,
281        break_on_token_text: Option<&str>,
282    ) -> ParseResult<Option<ParseNode>> {
283        let mut base = self.parse_group("atom", break_on_token_text)?;
284
285        if let Some(ref b) = base {
286            if b.type_name() == "internal" {
287                return Ok(base);
288            }
289        }
290
291        if self.mode == Mode::Text {
292            return Ok(base);
293        }
294
295        let mut superscript: Option<ParseNode> = None;
296        let mut subscript: Option<ParseNode> = None;
297
298        loop {
299            self.consume_spaces()?;
300            let lex = self.fetch()?;
301
302            if lex.text == "\\limits" || lex.text == "\\nolimits" {
303                let is_limits = lex.text == "\\limits";
304                self.consume();
305                if let Some(
306                    ParseNode::Op { limits, .. }
307                    | ParseNode::OperatorName { limits, .. },
308                ) = base.as_mut()
309                {
310                    *limits = is_limits;
311                }
312            } else if lex.text == "^" {
313                if superscript.is_some() {
314                    return Err(ParseError::new("Double superscript", Some(&lex)));
315                }
316                superscript = Some(self.handle_sup_subscript("superscript")?);
317            } else if lex.text == "_" {
318                if subscript.is_some() {
319                    return Err(ParseError::new("Double subscript", Some(&lex)));
320                }
321                subscript = Some(self.handle_sup_subscript("subscript")?);
322            } else if lex.text == "'" {
323                if superscript.is_some() {
324                    return Err(ParseError::new("Double superscript", Some(&lex)));
325                }
326                let prime = ParseNode::TextOrd {
327                    mode: self.mode,
328                    text: "\\prime".to_string(),
329                    loc: None,
330                };
331                let mut primes = vec![prime.clone()];
332                self.consume();
333                while self.fetch()?.text == "'" {
334                    primes.push(prime.clone());
335                    self.consume();
336                }
337                if self.fetch()?.text == "^" {
338                    primes.push(self.handle_sup_subscript("superscript")?);
339                }
340                superscript = Some(ParseNode::OrdGroup {
341                    mode: self.mode,
342                    body: primes,
343                    semisimple: None,
344                    loc: None,
345                });
346            } else if let Some((mapped, is_sub)) = lex
347                .text
348                .chars()
349                .next()
350                .and_then(crate::unicode_sup_sub::unicode_sub_sup)
351            {
352                if is_sub && subscript.is_some() {
353                    return Err(ParseError::new("Double subscript", Some(&lex)));
354                }
355                if !is_sub && superscript.is_some() {
356                    return Err(ParseError::new("Double superscript", Some(&lex)));
357                }
358                // Collect consecutive Unicode sup/sub chars of the same kind
359                let mut subsup_tokens = vec![Token::new(mapped, 0, 0)];
360                self.consume();
361                loop {
362                    let tok = self.fetch()?;
363                    match tok
364                        .text
365                        .chars()
366                        .next()
367                        .and_then(crate::unicode_sup_sub::unicode_sub_sup)
368                    {
369                        Some((m, sub)) if sub == is_sub => {
370                            subsup_tokens.insert(0, Token::new(m, 0, 0));
371                            self.consume();
372                        }
373                        _ => break,
374                    }
375                }
376                let body = self.subparse(subsup_tokens)?;
377                let group = ParseNode::OrdGroup {
378                    mode: Mode::Math,
379                    body,
380                    semisimple: None,
381                    loc: None,
382                };
383                if is_sub {
384                    subscript = Some(group);
385                } else {
386                    superscript = Some(group);
387                }
388            } else {
389                break;
390            }
391        }
392
393        if superscript.is_some() || subscript.is_some() {
394            Ok(Some(ParseNode::SupSub {
395                mode: self.mode,
396                base: base.map(Box::new),
397                sup: superscript.map(Box::new),
398                sub: subscript.map(Box::new),
399                loc: None,
400            }))
401        } else {
402            Ok(base)
403        }
404    }
405
406    /// Handle a subscript or superscript.
407    fn handle_sup_subscript(&mut self, name: &str) -> ParseResult<ParseNode> {
408        let symbol_token = self.fetch()?;
409        self.consume();
410        self.consume_spaces()?;
411
412        let group = self.parse_group(name, None)?;
413        match group {
414            Some(g) if g.type_name() != "internal" => Ok(g),
415            Some(_) => {
416                // Skip internal nodes, try again
417                let g2 = self.parse_group(name, None)?;
418                g2.ok_or_else(|| {
419                    ParseError::new(
420                        format!("Expected group after '{}'", symbol_token.text),
421                        Some(&symbol_token),
422                    )
423                })
424            }
425            None => Err(ParseError::new(
426                format!("Expected group after '{}'", symbol_token.text),
427                Some(&symbol_token),
428            )),
429        }
430    }
431
432    // ── Group parsing ───────────────────────────────────────────────────
433
434    /// Parse a group: braced expression, function call, or single symbol.
435    pub fn parse_group(
436        &mut self,
437        name: &str,
438        break_on_token_text: Option<&str>,
439    ) -> ParseResult<Option<ParseNode>> {
440        let first_token = self.fetch()?;
441        let text = first_token.text.clone();
442
443        if text == "{" || text == "\\begingroup" {
444            self.consume();
445            let group_end = if text == "{" { "}" } else { "\\endgroup" };
446
447            self.gullet.begin_group();
448            let expression = self.parse_expression(false, Some(group_end))?;
449            let last_token = self.fetch()?;
450            self.expect(group_end, true)?;
451            self.gullet.end_group();
452
453            let loc = Some(SourceLocation::range(&first_token.loc, &last_token.loc));
454            let semisimple = if text == "\\begingroup" {
455                Some(true)
456            } else {
457                None
458            };
459
460            Ok(Some(ParseNode::OrdGroup {
461                mode: self.mode,
462                body: expression,
463                semisimple,
464                loc,
465            }))
466        } else {
467            let result = self
468                .parse_function(break_on_token_text, Some(name))?
469                .or_else(|| self.parse_symbol_inner().ok().flatten());
470
471            if result.is_none()
472                && text.starts_with('\\')
473                && !IMPLICIT_COMMANDS.contains(&text.as_str())
474            {
475                return Err(ParseError::new(
476                    format!("Undefined control sequence: {}", text),
477                    Some(&first_token),
478                ));
479            }
480
481            Ok(result)
482        }
483    }
484
485    // ── Function parsing ────────────────────────────────────────────────
486
487    /// Try to parse a function call. Returns None if not a function.
488    pub fn parse_function(
489        &mut self,
490        break_on_token_text: Option<&str>,
491        name: Option<&str>,
492    ) -> ParseResult<Option<ParseNode>> {
493        let token = self.fetch()?;
494        let func = token.text.clone();
495
496        let func_data = match FUNCTIONS.get(func.as_str()) {
497            Some(f) => f,
498            None => return Ok(None),
499        };
500
501        self.consume();
502
503        if let Some(n) = name {
504            if n != "atom" && !func_data.allowed_in_argument {
505                return Err(ParseError::new(
506                    format!("Got function '{}' with no arguments as {}", func, n),
507                    Some(&token),
508                ));
509            }
510        }
511
512        functions::check_mode_compatibility(func_data, self.mode, &func, Some(&token))?;
513
514        let (args, opt_args) = self.parse_arguments(&func, func_data)?;
515
516        self.call_function(
517            &func,
518            args,
519            opt_args,
520            Some(token),
521            break_on_token_text.map(|s| s.to_string()).as_deref(),
522        )
523        .map(Some)
524    }
525
526    /// Call a function handler.
527    pub fn call_function(
528        &mut self,
529        name: &str,
530        args: Vec<ParseNode>,
531        opt_args: Vec<Option<ParseNode>>,
532        token: Option<Token>,
533        break_on_token_text: Option<&str>,
534    ) -> ParseResult<ParseNode> {
535        let func = FUNCTIONS.get(name).ok_or_else(|| {
536            ParseError::msg(format!("No function handler for {}", name))
537        })?;
538
539        let mut ctx = FunctionContext {
540            func_name: name.to_string(),
541            parser: self,
542            token: token.clone(),
543            break_on_token_text: break_on_token_text.map(|s| s.to_string()),
544        };
545
546        (func.handler)(&mut ctx, args, opt_args)
547    }
548
549    /// Parse the arguments for a function.
550    pub fn parse_arguments(
551        &mut self,
552        func: &str,
553        func_data: &functions::FunctionSpec,
554    ) -> ParseResult<(Vec<ParseNode>, Vec<Option<ParseNode>>)> {
555        let total_args = func_data.num_args + func_data.num_optional_args;
556        if total_args == 0 {
557            return Ok((Vec::new(), Vec::new()));
558        }
559
560        let mut args = Vec::new();
561        let mut opt_args = Vec::new();
562
563        for i in 0..total_args {
564            let arg_type = func_data
565                .arg_types
566                .as_ref()
567                .and_then(|types| types.get(i).copied());
568            let is_optional = i < func_data.num_optional_args;
569
570            let effective_type = if (func_data.primitive && arg_type.is_none())
571                || (func_data.node_type == "sqrt" && i == 1
572                    && opt_args.first().is_some_and(|o: &Option<ParseNode>| o.is_none()))
573            {
574                Some(ArgType::Primitive)
575            } else {
576                arg_type
577            };
578
579            let arg = self.parse_group_of_type(
580                &format!("argument to '{}'", func),
581                effective_type,
582                is_optional,
583            )?;
584
585            if is_optional {
586                opt_args.push(arg);
587            } else if let Some(a) = arg {
588                args.push(a);
589            } else {
590                return Err(ParseError::msg("Null argument, please report this as a bug"));
591            }
592        }
593
594        Ok((args, opt_args))
595    }
596
597    /// Parse a group with a specific type.
598    fn parse_group_of_type(
599        &mut self,
600        name: &str,
601        arg_type: Option<ArgType>,
602        optional: bool,
603    ) -> ParseResult<Option<ParseNode>> {
604        match arg_type {
605            Some(ArgType::Color) => self.parse_color_group(optional),
606            Some(ArgType::Size) => self.parse_size_group(optional),
607            Some(ArgType::Primitive) => {
608                if optional {
609                    return Err(ParseError::msg("A primitive argument cannot be optional"));
610                }
611                let group = self.parse_group(name, None)?;
612                match group {
613                    Some(g) => Ok(Some(g)),
614                    None => Err(ParseError::new(
615                        format!("Expected group as {}", name),
616                        None,
617                    )),
618                }
619            }
620            Some(ArgType::Math) | Some(ArgType::Text) => {
621                let mode = match arg_type {
622                    Some(ArgType::Math) => Some(Mode::Math),
623                    Some(ArgType::Text) => Some(Mode::Text),
624                    _ => None,
625                };
626                self.parse_argument_group(optional, mode)
627            }
628            Some(ArgType::HBox) => {
629                let group = self.parse_argument_group(optional, Some(Mode::Text))?;
630                match group {
631                    Some(g) => Ok(Some(ParseNode::Styling {
632                        mode: g.mode(),
633                        style: crate::parse_node::StyleStr::Text,
634                        body: vec![g],
635                        loc: None,
636                    })),
637                    None => Ok(None),
638                }
639            }
640            Some(ArgType::Raw) => {
641                let token = self.parse_string_group("raw", optional)?;
642                match token {
643                    Some(t) => Ok(Some(ParseNode::Raw {
644                        mode: Mode::Text,
645                        string: t.text,
646                        loc: None,
647                    })),
648                    None => Ok(None),
649                }
650            }
651            Some(ArgType::Url) => self.parse_url_group(optional),
652            None | Some(ArgType::Original) => self.parse_argument_group(optional, None),
653        }
654    }
655
656    /// Parse a color group.
657    fn parse_color_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
658        let res = self.parse_string_group("color", optional)?;
659        match res {
660            None => Ok(None),
661            Some(token) => {
662                let text = token.text.trim().to_string();
663                let re = regex_lite::Regex::new(
664                    r"^(#[a-fA-F0-9]{3,4}|#[a-fA-F0-9]{6}|#[a-fA-F0-9]{8}|[a-fA-F0-9]{6}|[a-zA-Z]+)$",
665                )
666                .unwrap();
667
668                if !re.is_match(&text) {
669                    return Err(ParseError::new(
670                        format!("Invalid color: '{}'", text),
671                        Some(&token),
672                    ));
673                }
674                let mut color = text;
675                if regex_lite::Regex::new(r"^[0-9a-fA-F]{6}$")
676                    .unwrap()
677                    .is_match(&color)
678                {
679                    color = format!("#{}", color);
680                }
681
682                Ok(Some(ParseNode::ColorToken {
683                    mode: self.mode,
684                    color,
685                    loc: None,
686                }))
687            }
688        }
689    }
690
691    /// Parse a size group (e.g., "3pt", "1em").
692    pub fn parse_size_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
693        let mut is_blank = false;
694
695        self.gullet.consume_spaces();
696        let res = if !optional && self.gullet.future().text != "{" {
697            Some(self.parse_regex_group(
698                &regex_lite::Regex::new(r"^[-+]? *(?:$|\d+|\d+\.\d*|\.\d*) *[a-z]{0,2} *$")
699                    .unwrap(),
700                "size",
701            )?)
702        } else {
703            self.parse_string_group("size", optional)?
704        };
705
706        let res = match res {
707            Some(r) => r,
708            None => return Ok(None),
709        };
710
711        let mut text = res.text.clone();
712        if !optional && text.is_empty() {
713            text = "0pt".to_string();
714            is_blank = true;
715        }
716
717        let size_re =
718            regex_lite::Regex::new(r"([-+]?) *(\d+(?:\.\d*)?|\.\d+) *([a-z]{2})").unwrap();
719        let m = size_re.captures(&text).ok_or_else(|| {
720            ParseError::new(format!("Invalid size: '{}'", text), Some(&res))
721        })?;
722
723        let sign = m.get(1).map_or("", |m| m.as_str());
724        let magnitude = m.get(2).map_or("", |m| m.as_str());
725        let unit = m.get(3).map_or("", |m| m.as_str());
726
727        let number: f64 = format!("{}{}", sign, magnitude).parse().unwrap_or(0.0);
728
729        if !is_valid_unit(unit) {
730            return Err(ParseError::new(
731                format!("Invalid unit: '{}'", unit),
732                Some(&res),
733            ));
734        }
735
736        Ok(Some(ParseNode::Size {
737            mode: self.mode,
738            value: crate::parse_node::Measurement {
739                number,
740                unit: unit.to_string(),
741            },
742            is_blank,
743            loc: None,
744        }))
745    }
746
747    /// Parse a URL group.
748    /// Temporarily disables `%` as comment character to allow `%20` etc. in URLs.
749    fn parse_url_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
750        self.gullet.lexer.set_catcode('%', 13);
751        self.gullet.lexer.set_catcode('~', 12);
752        let res = self.parse_string_group("url", optional);
753        self.gullet.lexer.set_catcode('%', 14);
754        self.gullet.lexer.set_catcode('~', 13);
755        let res = res?;
756        match res {
757            None => Ok(None),
758            Some(token) => {
759                let url = token.text;
760                Ok(Some(ParseNode::Url {
761                    mode: self.mode,
762                    url,
763                    loc: None,
764                }))
765            }
766        }
767    }
768
769    /// Parse a string group (brace-enclosed string).
770    fn parse_string_group(
771        &mut self,
772        _mode_name: &str,
773        optional: bool,
774    ) -> ParseResult<Option<Token>> {
775        let arg_token = self.gullet.scan_argument(optional)?;
776        let arg_token = match arg_token {
777            Some(t) => t,
778            None => return Ok(None),
779        };
780
781        let mut s = String::new();
782        loop {
783            let next = self.fetch()?;
784            if next.text == "EOF" {
785                break;
786            }
787            s.push_str(&next.text);
788            self.consume();
789        }
790        self.consume(); // consume EOF
791
792        let mut result = arg_token;
793        result.text = s;
794        Ok(Some(result))
795    }
796
797    /// Parse a regex-delimited group.
798    fn parse_regex_group(
799        &mut self,
800        regex: &regex_lite::Regex,
801        mode_name: &str,
802    ) -> ParseResult<Token> {
803        let first_token = self.fetch()?;
804        let mut last_token = first_token.clone();
805        let mut s = String::new();
806
807        loop {
808            let next = self.fetch()?;
809            if next.text == "EOF" {
810                break;
811            }
812            let candidate = format!("{}{}", s, next.text);
813            if regex.is_match(&candidate) {
814                last_token = next;
815                s = candidate;
816                self.consume();
817            } else {
818                break;
819            }
820        }
821
822        if s.is_empty() {
823            return Err(ParseError::new(
824                format!("Invalid {}: '{}'", mode_name, first_token.text),
825                Some(&first_token),
826            ));
827        }
828
829        Ok(first_token.range(&last_token, s))
830    }
831
832    /// Parse an argument group (with optional mode switch).
833    pub fn parse_argument_group(
834        &mut self,
835        optional: bool,
836        mode: Option<Mode>,
837    ) -> ParseResult<Option<ParseNode>> {
838        let arg_token = self.gullet.scan_argument(optional)?;
839        let arg_token = match arg_token {
840            Some(t) => t,
841            None => return Ok(None),
842        };
843
844        let outer_mode = self.mode;
845        if let Some(m) = mode {
846            self.switch_mode(m);
847        }
848
849        self.gullet.begin_group();
850        let expression = self.parse_expression(false, Some("EOF"))?;
851        self.expect("EOF", true)?;
852        self.gullet.end_group();
853
854        let result = ParseNode::OrdGroup {
855            mode: self.mode,
856            loc: Some(arg_token.loc.clone()),
857            body: expression,
858            semisimple: None,
859        };
860
861        if mode.is_some() {
862            self.switch_mode(outer_mode);
863        }
864
865        Ok(Some(result))
866    }
867
868    // ── Symbol parsing ──────────────────────────────────────────────────
869
870    /// Parse a single symbol (internal version that returns Result).
871    fn parse_symbol_inner(&mut self) -> ParseResult<Option<ParseNode>> {
872        let nucleus = self.fetch()?;
873        let text = nucleus.text.clone();
874
875        if let Some(stripped) = text.strip_prefix("\\verb") {
876            self.consume();
877            let arg = stripped.to_string();
878            let star = arg.starts_with('*');
879            let arg = if star { &arg[1..] } else { &arg };
880
881            if arg.len() < 2 {
882                return Err(ParseError::new("\\verb assertion failed", Some(&nucleus)));
883            }
884            let body = arg[1..arg.len() - 1].to_string();
885            return Ok(Some(ParseNode::Verb {
886                mode: Mode::Text,
887                body,
888                star,
889                loc: Some(nucleus.loc.clone()),
890            }));
891        }
892
893        let font_mode = match self.mode {
894            Mode::Math => ratex_font::symbols::Mode::Math,
895            Mode::Text => ratex_font::symbols::Mode::Text,
896        };
897
898        // ^ and _ are handled by parse_atom for sup/sub, not as symbol nodes
899        if text == "^" || text == "_" {
900            return Ok(None);
901        }
902
903        // Bare backslash (incomplete control sequence) → not a valid symbol
904        if text == "\\" {
905            return Ok(None);
906        }
907
908        if let Some(sym_info) = ratex_font::symbols::get_symbol(&text, font_mode) {
909            let loc = Some(SourceLocation::range(&nucleus.loc, &nucleus.loc));
910            let group = sym_info.group;
911
912            let node = if group.is_atom() {
913                let family = match group {
914                    ratex_font::symbols::Group::Bin => AtomFamily::Bin,
915                    ratex_font::symbols::Group::Close => AtomFamily::Close,
916                    ratex_font::symbols::Group::Inner => AtomFamily::Inner,
917                    ratex_font::symbols::Group::Open => AtomFamily::Open,
918                    ratex_font::symbols::Group::Punct => AtomFamily::Punct,
919                    ratex_font::symbols::Group::Rel => AtomFamily::Rel,
920                    _ => unreachable!(),
921                };
922                ParseNode::Atom {
923                    mode: self.mode,
924                    family,
925                    text: text.clone(),
926                    loc,
927                }
928            } else {
929                match group {
930                    ratex_font::symbols::Group::MathOrd => ParseNode::MathOrd {
931                        mode: self.mode,
932                        text: text.clone(),
933                        loc,
934                    },
935                    ratex_font::symbols::Group::TextOrd => ParseNode::TextOrd {
936                        mode: self.mode,
937                        text: text.clone(),
938                        loc,
939                    },
940                    ratex_font::symbols::Group::OpToken => ParseNode::OpToken {
941                        mode: self.mode,
942                        text: text.clone(),
943                        loc,
944                    },
945                    ratex_font::symbols::Group::AccentToken => ParseNode::AccentToken {
946                        mode: self.mode,
947                        text: text.clone(),
948                        loc,
949                    },
950                    ratex_font::symbols::Group::Spacing => ParseNode::SpacingNode {
951                        mode: self.mode,
952                        text: text.clone(),
953                        loc,
954                    },
955                    _ => ParseNode::MathOrd {
956                        mode: self.mode,
957                        text: text.clone(),
958                        loc,
959                    },
960                }
961            };
962
963            self.consume();
964            return Ok(Some(node));
965        }
966
967        // Unicode accented characters → decompose into accent nodes
968        // Handles both precomposed (á U+00E1) and combining forms (a + U+0301)
969        if let Some(node) = self.try_parse_unicode_accent(&text, &nucleus)? {
970            self.consume();
971            return Ok(Some(node));
972        }
973
974        // Non-ASCII characters without accent decomposition → treat as textord
975        // KaTeX always uses mode="text" for these, regardless of current mode
976        let first_char = text.chars().next();
977        if let Some(ch) = first_char {
978            if ch as u32 >= 0x80 {
979                let node = ParseNode::TextOrd {
980                    mode: Mode::Text,
981                    text: text.clone(),
982                    loc: Some(SourceLocation::range(&nucleus.loc, &nucleus.loc)),
983                };
984                self.consume();
985                return Ok(Some(node));
986            }
987        }
988
989        Ok(None)
990    }
991
992    /// Try to decompose a Unicode accented character into accent nodes.
993    /// Returns None if no decomposition is available.
994    /// Only decomposes Latin-script characters, matching KaTeX behavior.
995    fn try_parse_unicode_accent(
996        &self,
997        text: &str,
998        nucleus: &Token,
999    ) -> ParseResult<Option<ParseNode>> {
1000        let nfd: String = text.nfd().collect();
1001        let chars: Vec<char> = nfd.chars().collect();
1002
1003        if chars.len() < 2 {
1004            return Ok(None);
1005        }
1006
1007        // Build from the base up through each combining mark
1008        let mut split_idx = chars.len() - 1;
1009        while split_idx > 0 && is_supported_combining_accent(chars[split_idx]) {
1010            split_idx -= 1;
1011        }
1012
1013        // Verify ALL trailing chars are supported combining accents
1014        if split_idx == chars.len() - 1 {
1015            return Ok(None);
1016        }
1017
1018        // Only decompose Latin-script base characters
1019        let base_char = chars[0];
1020        if !is_latin_base_char(base_char) {
1021            return Ok(None);
1022        }
1023
1024        let loc = Some(SourceLocation::range(&nucleus.loc, &nucleus.loc));
1025
1026        // Base: everything before the combining marks
1027        let mut base_str: String = chars[..split_idx + 1].iter().collect();
1028
1029        // Accented i→ı and j→ȷ (dotless variants), matching KaTeX behavior
1030        if base_str.len() == 1 {
1031            match base_str.as_str() {
1032                "i" => base_str = "\u{0131}".to_string(), // ı
1033                "j" => base_str = "\u{0237}".to_string(), // ȷ
1034                _ => {}
1035            }
1036        }
1037
1038        let font_mode = match self.mode {
1039            Mode::Math => ratex_font::symbols::Mode::Math,
1040            Mode::Text => ratex_font::symbols::Mode::Text,
1041        };
1042
1043        let mut node = if base_str.chars().count() == 1 {
1044            let ch = base_str.chars().next().unwrap();
1045            if let Some(sym) = ratex_font::symbols::get_symbol(&base_str, font_mode) {
1046                match sym.group {
1047                    ratex_font::symbols::Group::TextOrd => ParseNode::TextOrd {
1048                        mode: self.mode,
1049                        text: base_str.clone(),
1050                        loc: loc.clone(),
1051                    },
1052                    _ => ParseNode::MathOrd {
1053                        mode: self.mode,
1054                        text: base_str.clone(),
1055                        loc: loc.clone(),
1056                    },
1057                }
1058            } else if (ch as u32) >= 0x80 {
1059                // Non-ASCII base chars always text mode (KaTeX compat)
1060                ParseNode::TextOrd {
1061                    mode: Mode::Text,
1062                    text: base_str.clone(),
1063                    loc: loc.clone(),
1064                }
1065            } else {
1066                ParseNode::MathOrd {
1067                    mode: self.mode,
1068                    text: base_str.clone(),
1069                    loc: loc.clone(),
1070                }
1071            }
1072        } else {
1073            return self.try_parse_unicode_accent(&base_str, nucleus).map(|opt| {
1074                opt.or_else(|| {
1075                    Some(ParseNode::TextOrd {
1076                        mode: Mode::Text,
1077                        text: base_str.clone(),
1078                        loc: loc.clone(),
1079                    })
1080                })
1081            });
1082        };
1083
1084        // Wrap in accent nodes from innermost to outermost
1085        for &combining in &chars[split_idx + 1..] {
1086            let label = combining_to_accent_label(combining, self.mode);
1087            node = ParseNode::Accent {
1088                mode: self.mode,
1089                label,
1090                is_stretchy: Some(false),
1091                is_shifty: Some(true),
1092                base: Box::new(node),
1093                loc: loc.clone(),
1094            };
1095        }
1096
1097        Ok(Some(node))
1098    }
1099
1100    /// Parse a sub-expression from the given tokens.
1101    pub fn subparse(&mut self, tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
1102        let old_token = self.next_token.take();
1103
1104        self.gullet
1105            .push_token(Token::new("}", 0, 0));
1106        self.gullet.push_tokens(tokens);
1107        let parse = self.parse_expression(false, None)?;
1108        self.expect("}", true)?;
1109
1110        self.next_token = old_token;
1111        Ok(parse)
1112    }
1113}
1114
/// Returns `true` if `ch` may serve as the base of a decomposed accent:
/// an ASCII letter, a dotless ı/ȷ, or one of a fixed set of Latin-1 letters.
fn is_latin_base_char(ch: char) -> bool {
    /// Non-ASCII letters accepted in addition to `A-Z` / `a-z`.
    const EXTRA_LATIN: [char; 11] = [
        '\u{0131}', // ı (dotless i)
        '\u{0237}', // ȷ (dotless j)
        '\u{00C6}', // Æ
        '\u{00D0}', // Ð
        '\u{00D8}', // Ø
        '\u{00DE}', // Þ
        '\u{00DF}', // ß
        '\u{00E6}', // æ
        '\u{00F0}', // ð
        '\u{00F8}', // ø
        '\u{00FE}', // þ
    ];

    ch.is_ascii_alphabetic() || EXTRA_LATIN.contains(&ch)
}
1131
/// Returns `true` if `ch` is one of the combining marks that can be turned
/// into an accent node: U+0300–0304, U+0306–0308, U+030A–030C and U+0327.
fn is_supported_combining_accent(ch: char) -> bool {
    matches!(
        ch,
        '\u{0300}'..='\u{0304}'
            | '\u{0306}'..='\u{0308}'
            | '\u{030A}'..='\u{030C}'
            | '\u{0327}'
    )
}
1140
1141fn combining_to_accent_label(ch: char, mode: Mode) -> String {
1142    match mode {
1143        Mode::Math => match ch {
1144            '\u{0300}' => "\\grave".to_string(),
1145            '\u{0301}' => "\\acute".to_string(),
1146            '\u{0302}' => "\\hat".to_string(),
1147            '\u{0303}' => "\\tilde".to_string(),
1148            '\u{0304}' => "\\bar".to_string(),
1149            '\u{0306}' => "\\breve".to_string(),
1150            '\u{0307}' => "\\dot".to_string(),
1151            '\u{0308}' => "\\ddot".to_string(),
1152            '\u{030A}' => "\\mathring".to_string(),
1153            '\u{030B}' => "\\H".to_string(),
1154            '\u{030C}' => "\\check".to_string(),
1155            '\u{0327}' => "\\c".to_string(),
1156            _ => format!("\\char\"{:X}", ch as u32),
1157        },
1158        Mode::Text => match ch {
1159            '\u{0300}' => "\\`".to_string(),
1160            '\u{0301}' => "\\'".to_string(),
1161            '\u{0302}' => "\\^".to_string(),
1162            '\u{0303}' => "\\~".to_string(),
1163            '\u{0304}' => "\\=".to_string(),
1164            '\u{0306}' => "\\u".to_string(),
1165            '\u{0307}' => "\\.".to_string(),
1166            '\u{0308}' => "\\\"".to_string(),
1167            '\u{030A}' => "\\r".to_string(),
1168            '\u{030B}' => "\\H".to_string(),
1169            '\u{030C}' => "\\v".to_string(),
1170            '\u{0327}' => "\\c".to_string(),
1171            _ => format!("\\char\"{:X}", ch as u32),
1172        },
1173    }
1174}
1175
/// Returns `true` if `unit` is exactly one of the recognised length units
/// (case-sensitive, no surrounding whitespace).
fn is_valid_unit(unit: &str) -> bool {
    const KNOWN_UNITS: [&str; 15] = [
        "pt", "mm", "cm", "in", "bp", "pc", "dd", "cc", "nd", "nc", "sp", "px", "ex", "em",
        "mu",
    ];

    KNOWN_UNITS.iter().any(|&known| known == unit)
}
1183
1184/// Convenience function: parse a LaTeX string and return the AST.
1185pub fn parse(input: &str) -> ParseResult<Vec<ParseNode>> {
1186    Parser::new(input).parse()
1187}
1188
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, asserting success and exactly one top-level node.
    fn parse_single(src: &str) -> Vec<ParseNode> {
        let nodes = parse(src).unwrap();
        assert_eq!(nodes.len(), 1);
        nodes
    }

    /// Asserts that `src` parses to a single node with the given type name.
    fn assert_single_type(src: &str, expected: &str) {
        let nodes = parse_single(src);
        assert_eq!(nodes[0].type_name(), expected);
    }

    #[test]
    fn test_parse_single_char() {
        assert_single_type("x", "mathord");
    }

    #[test]
    fn test_parse_addition() {
        let nodes = parse("a+b").unwrap();
        assert_eq!(nodes.len(), 3);
        // a, +, b
        let expected = ["mathord", "atom", "mathord"];
        for (node, want) in nodes.iter().zip(expected.iter()) {
            assert_eq!(node.type_name(), *want);
        }
    }

    #[test]
    fn test_parse_superscript() {
        assert_single_type("x^2", "supsub");
    }

    #[test]
    fn test_parse_subscript() {
        assert_single_type("a_i", "supsub");
    }

    #[test]
    fn test_parse_supsub() {
        let nodes = parse_single("x^2_i");
        assert_eq!(nodes[0].type_name(), "supsub");
        // Both scripts must end up on the same node.
        match &nodes[0] {
            ParseNode::SupSub { sup, sub, .. } => {
                assert!(sup.is_some());
                assert!(sub.is_some());
            }
            _ => panic!("Expected SupSub"),
        }
    }

    #[test]
    fn test_parse_group() {
        assert_single_type("{a+b}", "ordgroup");
    }

    #[test]
    fn test_parse_frac() {
        assert_single_type("\\frac{a}{b}", "genfrac");
    }

    #[test]
    fn test_parse_sqrt() {
        assert_single_type("\\sqrt{x}", "sqrt");
    }

    #[test]
    fn test_parse_sqrt_optional() {
        let nodes = parse_single("\\sqrt[3]{x}");
        // The optional `[3]` must be captured as the root index.
        match &nodes[0] {
            ParseNode::Sqrt { index, .. } => assert!(index.is_some()),
            _ => panic!("Expected Sqrt"),
        }
    }

    #[test]
    fn test_parse_nested() {
        assert_single_type("\\frac{\\sqrt{a^2+b^2}}{c}", "genfrac");
    }

    #[test]
    fn test_parse_empty() {
        let nodes = parse("").unwrap();
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_parse_double_superscript_error() {
        assert!(parse("x^2^3").is_err());
    }

    #[test]
    fn test_parse_unclosed_brace_error() {
        assert!(parse("{x").is_err());
    }

    #[test]
    fn test_parse_json_output() {
        let nodes = parse("x^2").unwrap();
        let rendered = serde_json::to_string_pretty(&nodes).unwrap();
        assert!(rendered.contains("supsub"));
    }
}