Skip to main content

ratex_parser/
parser.rs

1use ratex_lexer::token::{SourceLocation, Token};
2use unicode_normalization::UnicodeNormalization;
3
4use crate::error::{ParseError, ParseResult};
5use crate::functions::{self, ArgType, FunctionContext, FUNCTIONS};
6use crate::macro_expander::{MacroExpander, IMPLICIT_COMMANDS};
7use crate::parse_node::{AtomFamily, Mode, ParseNode};
8
/// Token texts that close the expression currently being parsed: a closing
/// brace, `\endgroup`, `\end`, `\right`, or the alignment tab `&`.
static END_OF_EXPRESSION: &[&str] = &["}", r"\endgroup", r"\end", r"\right", "&"];
11
12/// The LaTeX parser. Converts a token stream into a ParseNode AST.
13///
14/// Follows KaTeX's Parser.ts closely:
15/// - `parse()` → parse full expression
16/// - `parseExpression()` → parse a list of atoms
17/// - `parseAtom()` → parse one atom with optional super/subscripts
18/// - `parseGroup()` → parse a group (braced or single token)
19/// - `parseFunction()` → parse a function call with arguments
20/// - `parseSymbol()` → parse a single symbol
21pub struct Parser<'a> {
22    pub mode: Mode,
23    pub gullet: MacroExpander<'a>,
24    pub leftright_depth: i32,
25    next_token: Option<Token>,
26    pub equation_counter: usize,
27}
28
29impl<'a> Parser<'a> {
30    pub fn new(input: &'a str) -> Self {
31        Self {
32            mode: Mode::Math,
33            gullet: MacroExpander::new(input, Mode::Math),
34            leftright_depth: 0,
35            next_token: None,
36            equation_counter: 0,
37        }
38    }
39
40    // ── Token management ────────────────────────────────────────────────
41
42    /// Return the current lookahead token (fetching from gullet if needed).
43    pub fn fetch(&mut self) -> ParseResult<Token> {
44        if self.next_token.is_none() {
45            self.next_token = Some(self.gullet.expand_next_token()?);
46        }
47        Ok(self.next_token.clone().unwrap())
48    }
49
50    /// Discard the current lookahead token.
51    pub fn consume(&mut self) {
52        self.next_token = None;
53    }
54
55    /// Expect the next token to have the given text, consuming it.
56    pub fn expect(&mut self, text: &str, do_consume: bool) -> ParseResult<()> {
57        let tok = self.fetch()?;
58        if tok.text != text {
59            return Err(ParseError::new(
60                format!("Expected '{}', got '{}'", text, tok.text),
61                Some(&tok),
62            ));
63        }
64        if do_consume {
65            self.consume();
66        }
67        Ok(())
68    }
69
70    /// Consume spaces in math mode.
71    pub fn consume_spaces(&mut self) -> ParseResult<()> {
72        loop {
73            let tok = self.fetch()?;
74            if tok.text == " " {
75                self.consume();
76            } else {
77                break;
78            }
79        }
80        Ok(())
81    }
82
83    /// Switch between "math" and "text" modes.
84    pub fn switch_mode(&mut self, new_mode: Mode) {
85        self.mode = new_mode;
86        self.gullet.switch_mode(new_mode);
87    }
88
89    // ── Main parse entry ────────────────────────────────────────────────
90
91    /// Parse the entire input and return the AST.
92    pub fn parse(&mut self) -> ParseResult<Vec<ParseNode>> {
93        self.gullet.begin_group();
94
95        let result = self.parse_expression(false, None);
96
97        match result {
98            Ok(parse) => {
99                self.expect("EOF", true)?;
100                self.gullet.end_group();
101                Ok(parse)
102            }
103            Err(e) => {
104                self.gullet.end_groups();
105                Err(e)
106            }
107        }
108    }
109
110    // ── Expression parsing ──────────────────────────────────────────────
111
112    /// Parse an expression: a list of atoms.
113    pub fn parse_expression(
114        &mut self,
115        break_on_infix: bool,
116        break_on_token_text: Option<&str>,
117    ) -> ParseResult<Vec<ParseNode>> {
118        let mut body = Vec::new();
119
120        loop {
121            if self.mode == Mode::Math {
122                self.consume_spaces()?;
123            }
124
125            let lex = self.fetch()?;
126
127            if END_OF_EXPRESSION.contains(&lex.text.as_str()) {
128                break;
129            }
130            if let Some(break_text) = break_on_token_text {
131                if lex.text == break_text {
132                    break;
133                }
134            }
135            if break_on_infix {
136                if let Some(func) = FUNCTIONS.get(lex.text.as_str()) {
137                    if func.infix {
138                        break;
139                    }
140                }
141            }
142
143            let atom = self.parse_atom(break_on_token_text)?;
144
145            match atom {
146                None => break,
147                Some(node) if node.type_name() == "internal" => continue,
148                Some(node) => body.push(node),
149            }
150        }
151
152        if self.mode == Mode::Text {
153            self.form_ligatures(&mut body);
154        }
155
156        self.handle_infix_nodes(body)
157    }
158
159    /// Rewrite infix operators (e.g. \over → \frac).
160    fn handle_infix_nodes(&mut self, body: Vec<ParseNode>) -> ParseResult<Vec<ParseNode>> {
161        let mut over_index: Option<usize> = None;
162        let mut func_name: Option<String> = None;
163
164        for (i, node) in body.iter().enumerate() {
165            if let ParseNode::Infix { replace_with, .. } = node {
166                if over_index.is_some() {
167                    return Err(ParseError::msg("only one infix operator per group"));
168                }
169                over_index = Some(i);
170                func_name = Some(replace_with.clone());
171            }
172        }
173
174        if let (Some(idx), Some(fname)) = (over_index, func_name) {
175            let numer_body: Vec<ParseNode> = body[..idx].to_vec();
176            let denom_body: Vec<ParseNode> = body[idx + 1..].to_vec();
177
178            let numer = if numer_body.len() == 1 {
179                if let ParseNode::OrdGroup { .. } = &numer_body[0] {
180                    numer_body.into_iter().next().unwrap()
181                } else {
182                    ParseNode::OrdGroup {
183                        mode: self.mode,
184                        body: numer_body,
185                        semisimple: None,
186                        loc: None,
187                    }
188                }
189            } else {
190                ParseNode::OrdGroup {
191                    mode: self.mode,
192                    body: numer_body,
193                    semisimple: None,
194                    loc: None,
195                }
196            };
197
198            let denom = if denom_body.len() == 1 {
199                if let ParseNode::OrdGroup { .. } = &denom_body[0] {
200                    denom_body.into_iter().next().unwrap()
201                } else {
202                    ParseNode::OrdGroup {
203                        mode: self.mode,
204                        body: denom_body,
205                        semisimple: None,
206                        loc: None,
207                    }
208                }
209            } else {
210                ParseNode::OrdGroup {
211                    mode: self.mode,
212                    body: denom_body,
213                    semisimple: None,
214                    loc: None,
215                }
216            };
217
218            let node = if fname == "\\\\abovefrac" {
219                // \above passes the infix node (with bar size) as the middle argument
220                let infix_node = body[idx].clone();
221                self.call_function(&fname, vec![numer, infix_node, denom], vec![], None, None)?
222            } else {
223                self.call_function(&fname, vec![numer, denom], vec![], None, None)?
224            };
225            Ok(vec![node])
226        } else {
227            Ok(body)
228        }
229    }
230
231    /// Form ligatures in text mode (e.g. -- → –, --- → —).
232    fn form_ligatures(&self, group: &mut Vec<ParseNode>) {
233        let mut i = 0;
234        while i + 1 < group.len() {
235            let a_text = group[i].symbol_text().map(|s| s.to_string());
236            let b_text = group[i + 1].symbol_text().map(|s| s.to_string());
237
238            if let (Some(a), Some(b)) = (a_text, b_text) {
239                if group[i].type_name() == "textord" && group[i + 1].type_name() == "textord" {
240                    if a == "-" && b == "-" {
241                        if i + 2 < group.len() {
242                            if let Some(c) = group[i + 2].symbol_text() {
243                                if c == "-" && group[i + 2].type_name() == "textord" {
244                                    group[i] = ParseNode::TextOrd {
245                                        mode: Mode::Text,
246                                        text: "---".to_string(),
247                                        loc: None,
248                                    };
249                                    group.remove(i + 2);
250                                    group.remove(i + 1);
251                                    continue;
252                                }
253                            }
254                        }
255                        group[i] = ParseNode::TextOrd {
256                            mode: Mode::Text,
257                            text: "--".to_string(),
258                            loc: None,
259                        };
260                        group.remove(i + 1);
261                        continue;
262                    }
263                    if (a == "'" || a == "`") && b == a {
264                        group[i] = ParseNode::TextOrd {
265                            mode: Mode::Text,
266                            text: format!("{}{}", a, a),
267                            loc: None,
268                        };
269                        group.remove(i + 1);
270                        continue;
271                    }
272                }
273            }
274            i += 1;
275        }
276    }
277
278    // ── Atom parsing ────────────────────────────────────────────────────
279
280    /// Parse a single atom with optional super/subscripts.
281    pub fn parse_atom(
282        &mut self,
283        break_on_token_text: Option<&str>,
284    ) -> ParseResult<Option<ParseNode>> {
285        let mut base = self.parse_group("atom", break_on_token_text)?;
286
287        if let Some(ref b) = base {
288            if b.type_name() == "internal" {
289                return Ok(base);
290            }
291        }
292
293        if self.mode == Mode::Text {
294            return Ok(base);
295        }
296
297        let mut superscript: Option<ParseNode> = None;
298        let mut subscript: Option<ParseNode> = None;
299
300        loop {
301            self.consume_spaces()?;
302            let lex = self.fetch()?;
303
304            if lex.text == "\\limits" || lex.text == "\\nolimits" {
305                let is_limits = lex.text == "\\limits";
306                self.consume();
307                if let Some(
308                    ParseNode::Op { limits, .. }
309                    | ParseNode::OperatorName { limits, .. },
310                ) = base.as_mut()
311                {
312                    *limits = is_limits;
313                }
314            } else if lex.text == "^" {
315                if superscript.is_some() {
316                    return Err(ParseError::new("Double superscript", Some(&lex)));
317                }
318                superscript = Some(self.handle_sup_subscript("superscript")?);
319            } else if lex.text == "_" {
320                if subscript.is_some() {
321                    return Err(ParseError::new("Double subscript", Some(&lex)));
322                }
323                subscript = Some(self.handle_sup_subscript("subscript")?);
324            } else if lex.text == "'" {
325                if superscript.is_some() {
326                    return Err(ParseError::new("Double superscript", Some(&lex)));
327                }
328                let prime = ParseNode::TextOrd {
329                    mode: self.mode,
330                    text: "\\prime".to_string(),
331                    loc: None,
332                };
333                let mut primes = vec![prime.clone()];
334                self.consume();
335                while self.fetch()?.text == "'" {
336                    primes.push(prime.clone());
337                    self.consume();
338                }
339                if self.fetch()?.text == "^" {
340                    primes.push(self.handle_sup_subscript("superscript")?);
341                }
342                superscript = Some(ParseNode::OrdGroup {
343                    mode: self.mode,
344                    body: primes,
345                    semisimple: None,
346                    loc: None,
347                });
348            } else if let Some((mapped, is_sub)) = lex
349                .text
350                .chars()
351                .next()
352                .and_then(crate::unicode_sup_sub::unicode_sub_sup)
353            {
354                if is_sub && subscript.is_some() {
355                    return Err(ParseError::new("Double subscript", Some(&lex)));
356                }
357                if !is_sub && superscript.is_some() {
358                    return Err(ParseError::new("Double superscript", Some(&lex)));
359                }
360                // Collect consecutive Unicode sup/sub chars of the same kind
361                let mut subsup_tokens = vec![Token::new(mapped, 0, 0)];
362                self.consume();
363                loop {
364                    let tok = self.fetch()?;
365                    match tok
366                        .text
367                        .chars()
368                        .next()
369                        .and_then(crate::unicode_sup_sub::unicode_sub_sup)
370                    {
371                        Some((m, sub)) if sub == is_sub => {
372                            subsup_tokens.insert(0, Token::new(m, 0, 0));
373                            self.consume();
374                        }
375                        _ => break,
376                    }
377                }
378                let body = self.subparse(subsup_tokens)?;
379                let group = ParseNode::OrdGroup {
380                    mode: Mode::Math,
381                    body,
382                    semisimple: None,
383                    loc: None,
384                };
385                if is_sub {
386                    subscript = Some(group);
387                } else {
388                    superscript = Some(group);
389                }
390            } else {
391                break;
392            }
393        }
394
395        if superscript.is_some() || subscript.is_some() {
396            Ok(Some(ParseNode::SupSub {
397                mode: self.mode,
398                base: base.map(Box::new),
399                sup: superscript.map(Box::new),
400                sub: subscript.map(Box::new),
401                loc: None,
402            }))
403        } else {
404            Ok(base)
405        }
406    }
407
408    /// Handle a subscript or superscript.
409    fn handle_sup_subscript(&mut self, name: &str) -> ParseResult<ParseNode> {
410        let symbol_token = self.fetch()?;
411        self.consume();
412        self.consume_spaces()?;
413
414        let group = self.parse_group(name, None)?;
415        match group {
416            Some(g) if g.type_name() != "internal" => Ok(g),
417            Some(_) => {
418                // Skip internal nodes, try again
419                let g2 = self.parse_group(name, None)?;
420                g2.ok_or_else(|| {
421                    ParseError::new(
422                        format!("Expected group after '{}'", symbol_token.text),
423                        Some(&symbol_token),
424                    )
425                })
426            }
427            None => Err(ParseError::new(
428                format!("Expected group after '{}'", symbol_token.text),
429                Some(&symbol_token),
430            )),
431        }
432    }
433
434    // ── Group parsing ───────────────────────────────────────────────────
435
436    /// Parse a group: braced expression, function call, or single symbol.
437    pub fn parse_group(
438        &mut self,
439        name: &str,
440        break_on_token_text: Option<&str>,
441    ) -> ParseResult<Option<ParseNode>> {
442        let first_token = self.fetch()?;
443        let text = first_token.text.clone();
444
445        if text == "{" || text == "\\begingroup" {
446            self.consume();
447            let group_end = if text == "{" { "}" } else { "\\endgroup" };
448
449            self.gullet.begin_group();
450            let expression = self.parse_expression(false, Some(group_end))?;
451            let last_token = self.fetch()?;
452            self.expect(group_end, true)?;
453            self.gullet.end_group();
454
455            let loc = Some(SourceLocation::range(&first_token.loc, &last_token.loc));
456            let semisimple = if text == "\\begingroup" {
457                Some(true)
458            } else {
459                None
460            };
461
462            Ok(Some(ParseNode::OrdGroup {
463                mode: self.mode,
464                body: expression,
465                semisimple,
466                loc,
467            }))
468        } else {
469            let result = self
470                .parse_function(break_on_token_text, Some(name))?
471                .or_else(|| self.parse_symbol_inner().ok().flatten());
472
473            if result.is_none()
474                && text.starts_with('\\')
475                && !IMPLICIT_COMMANDS.contains(&text.as_str())
476            {
477                return Err(ParseError::new(
478                    format!("Undefined control sequence: {}", text),
479                    Some(&first_token),
480                ));
481            }
482
483            Ok(result)
484        }
485    }
486
487    // ── Function parsing ────────────────────────────────────────────────
488
489    /// Try to parse a function call. Returns None if not a function.
490    pub fn parse_function(
491        &mut self,
492        break_on_token_text: Option<&str>,
493        name: Option<&str>,
494    ) -> ParseResult<Option<ParseNode>> {
495        let token = self.fetch()?;
496        let func = token.text.clone();
497
498        let func_data = match FUNCTIONS.get(func.as_str()) {
499            Some(f) => f,
500            None => return Ok(None),
501        };
502
503        self.consume();
504
505        if let Some(n) = name {
506            if n != "atom" && !func_data.allowed_in_argument {
507                return Err(ParseError::new(
508                    format!("Got function '{}' with no arguments as {}", func, n),
509                    Some(&token),
510                ));
511            }
512        }
513
514        functions::check_mode_compatibility(func_data, self.mode, &func, Some(&token))?;
515
516        // `\hspace*{len}` — `*` is a separate token (not part of the control word); consume it here.
517        // Must use gullet peek/pop only: `parser.fetch()` without `consume()` advances the lexer and
518        // leaves `{` only in `next_token`, so `parse_size_group`'s `gullet.future()` would miss the brace.
519        if func == "\\hspace" {
520            self.gullet.consume_spaces();
521            if self.gullet.future().text == "*" {
522                self.gullet.pop_token();
523            }
524        }
525
526        let (args, opt_args) = self.parse_arguments(&func, func_data)?;
527
528        self.call_function(
529            &func,
530            args,
531            opt_args,
532            Some(token),
533            break_on_token_text.map(|s| s.to_string()).as_deref(),
534        )
535        .map(Some)
536    }
537
538    /// Call a function handler.
539    pub fn call_function(
540        &mut self,
541        name: &str,
542        args: Vec<ParseNode>,
543        opt_args: Vec<Option<ParseNode>>,
544        token: Option<Token>,
545        break_on_token_text: Option<&str>,
546    ) -> ParseResult<ParseNode> {
547        let func = FUNCTIONS.get(name).ok_or_else(|| {
548            ParseError::msg(format!("No function handler for {}", name))
549        })?;
550
551        let mut ctx = FunctionContext {
552            func_name: name.to_string(),
553            parser: self,
554            token: token.clone(),
555            break_on_token_text: break_on_token_text.map(|s| s.to_string()),
556        };
557
558        (func.handler)(&mut ctx, args, opt_args)
559    }
560
561    /// Parse the arguments for a function.
562    pub fn parse_arguments(
563        &mut self,
564        func: &str,
565        func_data: &functions::FunctionSpec,
566    ) -> ParseResult<(Vec<ParseNode>, Vec<Option<ParseNode>>)> {
567        let total_args = func_data.num_args + func_data.num_optional_args;
568        if total_args == 0 {
569            return Ok((Vec::new(), Vec::new()));
570        }
571
572        let mut args = Vec::new();
573        let mut opt_args = Vec::new();
574
575        for i in 0..total_args {
576            let arg_type = func_data
577                .arg_types
578                .as_ref()
579                .and_then(|types| types.get(i).copied());
580            let is_optional = i < func_data.num_optional_args;
581
582            let effective_type = if (func_data.primitive && arg_type.is_none())
583                || (func_data.node_type == "sqrt" && i == 1
584                    && opt_args.first().is_some_and(|o: &Option<ParseNode>| o.is_none()))
585            {
586                Some(ArgType::Primitive)
587            } else {
588                arg_type
589            };
590
591            let arg = self.parse_group_of_type(
592                &format!("argument to '{}'", func),
593                effective_type,
594                is_optional,
595            )?;
596
597            if is_optional {
598                opt_args.push(arg);
599            } else if let Some(a) = arg {
600                args.push(a);
601            } else {
602                return Err(ParseError::msg("Null argument, please report this as a bug"));
603            }
604        }
605
606        Ok((args, opt_args))
607    }
608
609    /// Parse a group with a specific type.
610    fn parse_group_of_type(
611        &mut self,
612        name: &str,
613        arg_type: Option<ArgType>,
614        optional: bool,
615    ) -> ParseResult<Option<ParseNode>> {
616        match arg_type {
617            Some(ArgType::Color) => self.parse_color_group(optional),
618            Some(ArgType::Size) => self.parse_size_group(optional),
619            Some(ArgType::Primitive) => {
620                if optional {
621                    return Err(ParseError::msg("A primitive argument cannot be optional"));
622                }
623                let group = self.parse_group(name, None)?;
624                match group {
625                    Some(g) => Ok(Some(g)),
626                    None => Err(ParseError::new(
627                        format!("Expected group as {}", name),
628                        None,
629                    )),
630                }
631            }
632            Some(ArgType::Math) | Some(ArgType::Text) => {
633                let mode = match arg_type {
634                    Some(ArgType::Math) => Some(Mode::Math),
635                    Some(ArgType::Text) => Some(Mode::Text),
636                    _ => None,
637                };
638                self.parse_argument_group(optional, mode)
639            }
640            Some(ArgType::HBox) => {
641                let group = self.parse_argument_group(optional, Some(Mode::Text))?;
642                match group {
643                    Some(g) => Ok(Some(ParseNode::Styling {
644                        mode: g.mode(),
645                        style: crate::parse_node::StyleStr::Text,
646                        body: vec![g],
647                        loc: None,
648                    })),
649                    None => Ok(None),
650                }
651            }
652            Some(ArgType::Raw) => {
653                let token = self.parse_string_group("raw", optional)?;
654                match token {
655                    Some(t) => Ok(Some(ParseNode::Raw {
656                        mode: Mode::Text,
657                        string: t.text,
658                        loc: None,
659                    })),
660                    None => Ok(None),
661                }
662            }
663            Some(ArgType::Url) => self.parse_url_group(optional),
664            None | Some(ArgType::Original) => self.parse_argument_group(optional, None),
665        }
666    }
667
668    /// Parse a color group.
669    fn parse_color_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
670        let res = self.parse_string_group("color", optional)?;
671        match res {
672            None => Ok(None),
673            Some(token) => {
674                let text = token.text.trim().to_string();
675                let re = regex_lite::Regex::new(
676                    r"^(#[a-fA-F0-9]{3,4}|#[a-fA-F0-9]{6}|#[a-fA-F0-9]{8}|[a-fA-F0-9]{6}|[a-zA-Z]+|\d+(\.\d+)?(,\d+(\.\d+)?)*)$",
677                )
678                .unwrap();
679
680                if !re.is_match(&text) {
681                    return Err(ParseError::new(
682                        format!("Invalid color: '{}'", text),
683                        Some(&token),
684                    ));
685                }
686                let mut color = text;
687                if regex_lite::Regex::new(r"^[0-9a-fA-F]{6}$")
688                    .unwrap()
689                    .is_match(&color)
690                {
691                    color = format!("#{}", color);
692                }
693
694                Ok(Some(ParseNode::ColorToken {
695                    mode: self.mode,
696                    color,
697                    loc: None,
698                }))
699            }
700        }
701    }
702
703    /// Parse a size group (e.g., "3pt", "1em").
704    pub fn parse_size_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
705        let mut is_blank = false;
706
707        self.gullet.consume_spaces();
708        let res = if !optional && self.gullet.future().text != "{" {
709            Some(self.parse_regex_group(
710                &regex_lite::Regex::new(r"^[-+]? *(?:$|\d+|\d+\.\d*|\.\d*) *[a-z]{0,2} *$")
711                    .unwrap(),
712                "size",
713            )?)
714        } else {
715            self.parse_string_group("size", optional)?
716        };
717
718        let res = match res {
719            Some(r) => r,
720            None => return Ok(None),
721        };
722
723        let mut text = res.text.clone();
724        if !optional && text.is_empty() {
725            text = "0pt".to_string();
726            is_blank = true;
727        }
728
729        let size_re =
730            regex_lite::Regex::new(r"([-+]?) *(\d+(?:\.\d*)?|\.\d+) *([a-z]{2})").unwrap();
731        let m = size_re.captures(&text).ok_or_else(|| {
732            ParseError::new(format!("Invalid size: '{}'", text), Some(&res))
733        })?;
734
735        let sign = m.get(1).map_or("", |m| m.as_str());
736        let magnitude = m.get(2).map_or("", |m| m.as_str());
737        let unit = m.get(3).map_or("", |m| m.as_str());
738
739        let number: f64 = format!("{}{}", sign, magnitude).parse().unwrap_or(0.0);
740
741        if !is_valid_unit(unit) {
742            return Err(ParseError::new(
743                format!("Invalid unit: '{}'", unit),
744                Some(&res),
745            ));
746        }
747
748        Ok(Some(ParseNode::Size {
749            mode: self.mode,
750            value: crate::parse_node::Measurement {
751                number,
752                unit: unit.to_string(),
753            },
754            is_blank,
755            loc: None,
756        }))
757    }
758
759    /// Parse a URL group.
760    /// Temporarily disables `%` as comment character to allow `%20` etc. in URLs.
761    fn parse_url_group(&mut self, optional: bool) -> ParseResult<Option<ParseNode>> {
762        self.gullet.lexer.set_catcode('%', 13);
763        self.gullet.lexer.set_catcode('~', 12);
764        let res = self.parse_string_group("url", optional);
765        self.gullet.lexer.set_catcode('%', 14);
766        self.gullet.lexer.set_catcode('~', 13);
767        let res = res?;
768        match res {
769            None => Ok(None),
770            Some(token) => {
771                let url = token.text;
772                Ok(Some(ParseNode::Url {
773                    mode: self.mode,
774                    url,
775                    loc: None,
776                }))
777            }
778        }
779    }
780
781    /// Parse a string group (brace-enclosed string).
782    fn parse_string_group(
783        &mut self,
784        _mode_name: &str,
785        optional: bool,
786    ) -> ParseResult<Option<Token>> {
787        let arg_token = self.gullet.scan_argument(optional)?;
788        let arg_token = match arg_token {
789            Some(t) => t,
790            None => return Ok(None),
791        };
792
793        let mut s = String::new();
794        loop {
795            let next = self.fetch()?;
796            if next.text == "EOF" {
797                break;
798            }
799            s.push_str(&next.text);
800            self.consume();
801        }
802        self.consume(); // consume EOF
803
804        let mut result = arg_token;
805        result.text = s;
806        Ok(Some(result))
807    }
808
809    /// Parse a regex-delimited group.
810    fn parse_regex_group(
811        &mut self,
812        regex: &regex_lite::Regex,
813        mode_name: &str,
814    ) -> ParseResult<Token> {
815        let first_token = self.fetch()?;
816        let mut last_token = first_token.clone();
817        let mut s = String::new();
818
819        loop {
820            let next = self.fetch()?;
821            if next.text == "EOF" {
822                break;
823            }
824            let candidate = format!("{}{}", s, next.text);
825            if regex.is_match(&candidate) {
826                last_token = next;
827                s = candidate;
828                self.consume();
829            } else {
830                break;
831            }
832        }
833
834        if s.is_empty() {
835            return Err(ParseError::new(
836                format!("Invalid {}: '{}'", mode_name, first_token.text),
837                Some(&first_token),
838            ));
839        }
840
841        Ok(first_token.range(&last_token, s))
842    }
843
844    /// Parse an argument group (with optional mode switch).
845    pub fn parse_argument_group(
846        &mut self,
847        optional: bool,
848        mode: Option<Mode>,
849    ) -> ParseResult<Option<ParseNode>> {
850        let arg_token = self.gullet.scan_argument(optional)?;
851        let arg_token = match arg_token {
852            Some(t) => t,
853            None => return Ok(None),
854        };
855
856        let outer_mode = self.mode;
857        if let Some(m) = mode {
858            self.switch_mode(m);
859        }
860
861        self.gullet.begin_group();
862        let expression = self.parse_expression(false, Some("EOF"))?;
863        self.expect("EOF", true)?;
864        self.gullet.end_group();
865
866        let result = ParseNode::OrdGroup {
867            mode: self.mode,
868            loc: Some(arg_token.loc.clone()),
869            body: expression,
870            semisimple: None,
871        };
872
873        if mode.is_some() {
874            self.switch_mode(outer_mode);
875        }
876
877        Ok(Some(result))
878    }
879
880    // ── Symbol parsing ──────────────────────────────────────────────────
881
882    /// Parse a single symbol (internal version that returns Result).
883    fn parse_symbol_inner(&mut self) -> ParseResult<Option<ParseNode>> {
884        let nucleus = self.fetch()?;
885        let text = nucleus.text.clone();
886
887        if let Some(stripped) = text.strip_prefix("\\verb") {
888            self.consume();
889            let arg = stripped.to_string();
890            let star = arg.starts_with('*');
891            let arg = if star { &arg[1..] } else { &arg };
892
893            if arg.len() < 2 {
894                return Err(ParseError::new("\\verb assertion failed", Some(&nucleus)));
895            }
896            let body = arg[1..arg.len() - 1].to_string();
897            return Ok(Some(ParseNode::Verb {
898                mode: Mode::Text,
899                body,
900                star,
901                loc: Some(nucleus.loc.clone()),
902            }));
903        }
904
905        let font_mode = match self.mode {
906            Mode::Math => ratex_font::symbols::Mode::Math,
907            Mode::Text => ratex_font::symbols::Mode::Text,
908        };
909
910        // ^ and _ are handled by parse_atom for sup/sub, not as symbol nodes
911        if text == "^" || text == "_" {
912            return Ok(None);
913        }
914
915        // Bare backslash (incomplete control sequence) → not a valid symbol
916        if text == "\\" {
917            return Ok(None);
918        }
919
920        if let Some(sym_info) = ratex_font::symbols::get_symbol(&text, font_mode) {
921            let loc = Some(SourceLocation::range(&nucleus.loc, &nucleus.loc));
922            let group = sym_info.group;
923
924            let node = if group.is_atom() {
925                let family = match group {
926                    ratex_font::symbols::Group::Bin => AtomFamily::Bin,
927                    ratex_font::symbols::Group::Close => AtomFamily::Close,
928                    ratex_font::symbols::Group::Inner => AtomFamily::Inner,
929                    ratex_font::symbols::Group::Open => AtomFamily::Open,
930                    ratex_font::symbols::Group::Punct => AtomFamily::Punct,
931                    ratex_font::symbols::Group::Rel => AtomFamily::Rel,
932                    _ => unreachable!(),
933                };
934                ParseNode::Atom {
935                    mode: self.mode,
936                    family,
937                    text: text.clone(),
938                    loc,
939                }
940            } else {
941                match group {
942                    ratex_font::symbols::Group::MathOrd => ParseNode::MathOrd {
943                        mode: self.mode,
944                        text: text.clone(),
945                        loc,
946                    },
947                    ratex_font::symbols::Group::TextOrd => ParseNode::TextOrd {
948                        mode: self.mode,
949                        text: text.clone(),
950                        loc,
951                    },
952                    ratex_font::symbols::Group::OpToken => ParseNode::OpToken {
953                        mode: self.mode,
954                        text: text.clone(),
955                        loc,
956                    },
957                    ratex_font::symbols::Group::AccentToken => ParseNode::AccentToken {
958                        mode: self.mode,
959                        text: text.clone(),
960                        loc,
961                    },
962                    ratex_font::symbols::Group::Spacing => ParseNode::SpacingNode {
963                        mode: self.mode,
964                        text: text.clone(),
965                        loc,
966                    },
967                    _ => ParseNode::MathOrd {
968                        mode: self.mode,
969                        text: text.clone(),
970                        loc,
971                    },
972                }
973            };
974
975            self.consume();
976            return Ok(Some(node));
977        }
978
979        // Unicode accented characters → decompose into accent nodes
980        // Handles both precomposed (á U+00E1) and combining forms (a + U+0301)
981        if let Some(node) = self.try_parse_unicode_accent(&text, &nucleus)? {
982            self.consume();
983            return Ok(Some(node));
984        }
985
986        // Non-ASCII characters without accent decomposition → treat as textord
987        // KaTeX always uses mode="text" for these, regardless of current mode
988        let first_char = text.chars().next();
989        if let Some(ch) = first_char {
990            if ch as u32 >= 0x80 {
991                let node = ParseNode::TextOrd {
992                    mode: Mode::Text,
993                    text: text.clone(),
994                    loc: Some(SourceLocation::range(&nucleus.loc, &nucleus.loc)),
995                };
996                self.consume();
997                return Ok(Some(node));
998            }
999        }
1000
1001        Ok(None)
1002    }
1003
1004    /// Try to decompose a Unicode accented character into accent nodes.
1005    /// Returns None if no decomposition is available.
1006    /// Only decomposes Latin-script characters, matching KaTeX behavior.
1007    fn try_parse_unicode_accent(
1008        &self,
1009        text: &str,
1010        nucleus: &Token,
1011    ) -> ParseResult<Option<ParseNode>> {
1012        let nfd: String = text.nfd().collect();
1013        let chars: Vec<char> = nfd.chars().collect();
1014
1015        if chars.len() < 2 {
1016            return Ok(None);
1017        }
1018
1019        // Build from the base up through each combining mark
1020        let mut split_idx = chars.len() - 1;
1021        while split_idx > 0 && is_supported_combining_accent(chars[split_idx]) {
1022            split_idx -= 1;
1023        }
1024
1025        // Verify ALL trailing chars are supported combining accents
1026        if split_idx == chars.len() - 1 {
1027            return Ok(None);
1028        }
1029
1030        // Only decompose Latin-script base characters
1031        let base_char = chars[0];
1032        if !is_latin_base_char(base_char) {
1033            return Ok(None);
1034        }
1035
1036        let loc = Some(SourceLocation::range(&nucleus.loc, &nucleus.loc));
1037
1038        // Base: everything before the combining marks
1039        let mut base_str: String = chars[..split_idx + 1].iter().collect();
1040
1041        // Accented i→ı and j→ȷ (dotless variants), matching KaTeX behavior
1042        if base_str.len() == 1 {
1043            match base_str.as_str() {
1044                "i" => base_str = "\u{0131}".to_string(), // ı
1045                "j" => base_str = "\u{0237}".to_string(), // ȷ
1046                _ => {}
1047            }
1048        }
1049
1050        let font_mode = match self.mode {
1051            Mode::Math => ratex_font::symbols::Mode::Math,
1052            Mode::Text => ratex_font::symbols::Mode::Text,
1053        };
1054
1055        let mut node = if base_str.chars().count() == 1 {
1056            let ch = base_str.chars().next().unwrap();
1057            if let Some(sym) = ratex_font::symbols::get_symbol(&base_str, font_mode) {
1058                match sym.group {
1059                    ratex_font::symbols::Group::TextOrd => ParseNode::TextOrd {
1060                        mode: self.mode,
1061                        text: base_str.clone(),
1062                        loc: loc.clone(),
1063                    },
1064                    _ => ParseNode::MathOrd {
1065                        mode: self.mode,
1066                        text: base_str.clone(),
1067                        loc: loc.clone(),
1068                    },
1069                }
1070            } else if (ch as u32) >= 0x80 {
1071                // Non-ASCII base chars always text mode (KaTeX compat)
1072                ParseNode::TextOrd {
1073                    mode: Mode::Text,
1074                    text: base_str.clone(),
1075                    loc: loc.clone(),
1076                }
1077            } else {
1078                ParseNode::MathOrd {
1079                    mode: self.mode,
1080                    text: base_str.clone(),
1081                    loc: loc.clone(),
1082                }
1083            }
1084        } else {
1085            return self.try_parse_unicode_accent(&base_str, nucleus).map(|opt| {
1086                opt.or_else(|| {
1087                    Some(ParseNode::TextOrd {
1088                        mode: Mode::Text,
1089                        text: base_str.clone(),
1090                        loc: loc.clone(),
1091                    })
1092                })
1093            });
1094        };
1095
1096        // Wrap in accent nodes from innermost to outermost
1097        for &combining in &chars[split_idx + 1..] {
1098            let label = combining_to_accent_label(combining, self.mode);
1099            node = ParseNode::Accent {
1100                mode: self.mode,
1101                label,
1102                is_stretchy: Some(false),
1103                is_shifty: Some(true),
1104                base: Box::new(node),
1105                loc: loc.clone(),
1106            };
1107        }
1108
1109        Ok(Some(node))
1110    }
1111
1112    /// Parse a sub-expression from the given tokens.
1113    pub fn subparse(&mut self, tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
1114        let old_token = self.next_token.take();
1115
1116        self.gullet
1117            .push_token(Token::new("}", 0, 0));
1118        self.gullet.push_tokens(tokens);
1119        let parse = self.parse_expression(false, None)?;
1120        self.expect("}", true)?;
1121
1122        self.next_token = old_token;
1123        Ok(parse)
1124    }
1125}
1126
/// Whether `ch` may serve as the base letter when decomposing a Unicode
/// accented character.
///
/// Accepts the ASCII letters `A`–`Z` / `a`–`z` plus a fixed set of extra
/// Latin letters.
fn is_latin_base_char(ch: char) -> bool {
    // Non-ASCII Latin letters accepted in addition to the ASCII alphabet.
    const EXTRA_LATIN: [char; 11] = [
        '\u{0131}', // ı (dotless i)
        '\u{0237}', // ȷ (dotless j)
        '\u{00C6}', // Æ
        '\u{00D0}', // Ð
        '\u{00D8}', // Ø
        '\u{00DE}', // Þ
        '\u{00DF}', // ß
        '\u{00E6}', // æ
        '\u{00F0}', // ð
        '\u{00F8}', // ø
        '\u{00FE}', // þ
    ];
    ch.is_ascii_alphabetic() || EXTRA_LATIN.contains(&ch)
}
1143
/// Whether `ch` is a combining diacritic that can be turned into an accent
/// node.
///
/// Supported marks are U+0300–U+0304, U+0306–U+0308 and U+030A–U+030C, plus
/// the cedilla U+0327. U+0305 and U+0309 are deliberately absent.
fn is_supported_combining_accent(ch: char) -> bool {
    matches!(
        ch,
        '\u{0300}'..='\u{0304}' | '\u{0306}'..='\u{0308}' | '\u{030A}'..='\u{030C}' | '\u{0327}'
    )
}
1152
1153fn combining_to_accent_label(ch: char, mode: Mode) -> String {
1154    match mode {
1155        Mode::Math => match ch {
1156            '\u{0300}' => "\\grave".to_string(),
1157            '\u{0301}' => "\\acute".to_string(),
1158            '\u{0302}' => "\\hat".to_string(),
1159            '\u{0303}' => "\\tilde".to_string(),
1160            '\u{0304}' => "\\bar".to_string(),
1161            '\u{0306}' => "\\breve".to_string(),
1162            '\u{0307}' => "\\dot".to_string(),
1163            '\u{0308}' => "\\ddot".to_string(),
1164            '\u{030A}' => "\\mathring".to_string(),
1165            '\u{030B}' => "\\H".to_string(),
1166            '\u{030C}' => "\\check".to_string(),
1167            '\u{0327}' => "\\c".to_string(),
1168            _ => format!("\\char\"{:X}", ch as u32),
1169        },
1170        Mode::Text => match ch {
1171            '\u{0300}' => "\\`".to_string(),
1172            '\u{0301}' => "\\'".to_string(),
1173            '\u{0302}' => "\\^".to_string(),
1174            '\u{0303}' => "\\~".to_string(),
1175            '\u{0304}' => "\\=".to_string(),
1176            '\u{0306}' => "\\u".to_string(),
1177            '\u{0307}' => "\\.".to_string(),
1178            '\u{0308}' => "\\\"".to_string(),
1179            '\u{030A}' => "\\r".to_string(),
1180            '\u{030B}' => "\\H".to_string(),
1181            '\u{030C}' => "\\v".to_string(),
1182            '\u{0327}' => "\\c".to_string(),
1183            _ => format!("\\char\"{:X}", ch as u32),
1184        },
1185    }
1186}
1187
/// Whether `unit` is one of the recognized TeX dimension units.
///
/// Accepted are the absolute units (`pt`, `mm`, `cm`, `in`, `bp`, `pc`, `dd`,
/// `cc`, `nd`, `nc`, `sp`, `px`) and the relative ones (`ex`, `em`, `mu`).
/// Matching is exact and case-sensitive.
fn is_valid_unit(unit: &str) -> bool {
    const ABSOLUTE: [&str; 12] = [
        "pt", "mm", "cm", "in", "bp", "pc", "dd", "cc", "nd", "nc", "sp", "px",
    ];
    const RELATIVE: [&str; 3] = ["ex", "em", "mu"];
    ABSOLUTE.contains(&unit) || RELATIVE.contains(&unit)
}
1195
/// If the whole expression is wrapped in TeX inline/display math delimiters, parse the inside only.
/// The parser already runs in math mode; a leading `$` would otherwise hit the `$` / `\\(` "switch to math"
/// handler, which is disallowed in math mode (see `functions::math`).
///
/// The input is trimmed first. A `$$…$$` pair is tried before a `$…$` pair,
/// and whatever remains inside is trimmed again. Input without a matching
/// outer pair is returned trimmed but otherwise unchanged.
fn strip_outer_math_delimiters(input: &str) -> &str {
    let trimmed = input.trim();

    let display = trimmed
        .strip_prefix("$$")
        .and_then(|rest| rest.strip_suffix("$$"));
    if let Some(inner) = display {
        return inner.trim();
    }

    trimmed
        .strip_prefix('$')
        .and_then(|rest| rest.strip_suffix('$'))
        .map_or(trimmed, |inner| inner.trim())
}
1209
1210/// Convenience function: parse a LaTeX string and return the AST.
1211pub fn parse(input: &str) -> ParseResult<Vec<ParseNode>> {
1212    Parser::new(strip_outer_math_delimiters(input)).parse()
1213}
1214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_single_char() {
        let result = parse("x").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "mathord");
    }

    #[test]
    fn test_parse_strips_outer_dollar_inline_math() {
        let inner = r"C_p[\ce{H2O(l)}] = \pu{75.3 J // mol K}";
        let wrapped = format!("${inner}$");
        let a = parse(&wrapped).expect("wrapped");
        let b = parse(inner).expect("inner");
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b.iter()) {
            assert_eq!(x.type_name(), y.type_name());
        }
    }

    /// The `$$…$$` (display) branch of `strip_outer_math_delimiters`,
    /// exercised end-to-end through `parse`.
    #[test]
    fn test_parse_strips_outer_dollar_display_math() {
        let result = parse("  $$x^2$$  ").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "supsub");
    }

    /// Edge cases of the delimiter stripper itself.
    #[test]
    fn test_strip_outer_math_delimiters() {
        assert_eq!(strip_outer_math_delimiters("  $x$  "), "x");
        assert_eq!(strip_outer_math_delimiters("$$ a+b $$"), "a+b");
        assert_eq!(strip_outer_math_delimiters("x^2"), "x^2");
        assert_eq!(strip_outer_math_delimiters("$$"), "");
        assert_eq!(strip_outer_math_delimiters("$"), "$");
    }

    #[test]
    fn test_is_valid_unit() {
        for unit in ["pt", "em", "mu", "px"].iter() {
            assert!(is_valid_unit(unit));
        }
        for unit in ["", "PT", "deg", "p"].iter() {
            assert!(!is_valid_unit(unit));
        }
    }

    #[test]
    fn test_parse_addition() {
        let result = parse("a+b").unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].type_name(), "mathord"); // a
        assert_eq!(result[1].type_name(), "atom"); // +
        assert_eq!(result[2].type_name(), "mathord"); // b
    }

    #[test]
    fn test_parse_superscript() {
        let result = parse("x^2").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "supsub");
    }

    #[test]
    fn test_parse_subscript() {
        let result = parse("a_i").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "supsub");
    }

    #[test]
    fn test_parse_supsub() {
        let result = parse("x^2_i").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "supsub");
        if let ParseNode::SupSub { sup, sub, .. } = &result[0] {
            assert!(sup.is_some());
            assert!(sub.is_some());
        } else {
            panic!("Expected SupSub");
        }
    }

    #[test]
    fn test_parse_group() {
        let result = parse("{a+b}").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "ordgroup");
    }

    #[test]
    fn test_parse_frac() {
        let result = parse("\\frac{a}{b}").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "genfrac");
    }

    #[test]
    fn test_parse_sqrt() {
        let result = parse("\\sqrt{x}").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "sqrt");
    }

    #[test]
    fn test_parse_sqrt_optional() {
        let result = parse("\\sqrt[3]{x}").unwrap();
        assert_eq!(result.len(), 1);
        if let ParseNode::Sqrt { index, .. } = &result[0] {
            assert!(index.is_some());
        } else {
            panic!("Expected Sqrt");
        }
    }

    #[test]
    fn test_parse_nested() {
        let result = parse("\\frac{\\sqrt{a^2+b^2}}{c}").unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].type_name(), "genfrac");
    }

    #[test]
    fn test_parse_empty() {
        let result = parse("").unwrap();
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_parse_double_superscript_error() {
        let result = parse("x^2^3");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_unclosed_brace_error() {
        let result = parse("{x");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_json_output() {
        let result = parse("x^2").unwrap();
        let json = serde_json::to_string_pretty(&result).unwrap();
        assert!(json.contains("supsub"));
    }
}