Skip to main content

ratex_parser/
environments.rs

1use std::collections::HashMap;
2
3use ratex_lexer::token::Token;
4
5use crate::error::{ParseError, ParseResult};
6use crate::macro_expander::MacroDefinition;
7use crate::parse_node::{
8    AlignSpec, AlignType, ArrayTag, Measurement, Mode, ParseNode, ProofBranch, ProofLineStyle,
9    StyleStr,
10};
11use crate::parser::Parser;
12
13// ── Environment registry ─────────────────────────────────────────────────
14
/// Per-invocation context handed to an environment handler.
pub struct EnvContext<'a, 'b> {
    /// Mode in effect where the environment was opened; handlers in this file use it as the
    /// `mode` of the nodes they build around the parsed array.
    pub mode: Mode,
    /// Environment name (e.g. `"align*"`), including any trailing `*`; handlers branch on it.
    pub env_name: String,
    /// Parser from which the handler reads the environment body (presumably positioned just
    /// after the environment's arguments — confirm against the caller).
    pub parser: &'a mut Parser<'b>,
}
20
/// Signature shared by all environment handlers.
///
/// `args` holds the parsed mandatory arguments (e.g. `array` reads its column spec from
/// `args[0]`), `opt_args` the optional ones; the handler parses the body through
/// `ctx.parser` and returns the resulting node.
pub type EnvHandler = fn(
    ctx: &mut EnvContext,
    args: Vec<ParseNode>,
    opt_args: Vec<Option<ParseNode>>,
) -> ParseResult<ParseNode>;
26
/// Registry entry describing how to parse one environment.
pub struct EnvSpec {
    /// Number of mandatory arguments (presumably parsed by the caller before `handler` runs;
    /// e.g. 1 for `array`'s column spec and for `alignat`'s column-pair count).
    pub num_args: usize,
    /// Number of optional arguments.
    pub num_optional_args: usize,
    /// Function that parses the environment body and builds its node.
    pub handler: EnvHandler,
}
32
33pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
34    std::sync::LazyLock::new(|| {
35        let mut map = HashMap::new();
36        register_array(&mut map);
37        register_matrix(&mut map);
38        register_cases(&mut map);
39        register_align(&mut map);
40        register_gathered(&mut map);
41        register_equation(&mut map);
42        register_smallmatrix(&mut map);
43        register_alignat(&mut map);
44        register_subarray(&mut map);
45        register_cd(&mut map);
46        register_prooftree(&mut map);
47        map
48    });
49
50// ── ArrayConfig ──────────────────────────────────────────────────────────
51
/// Options controlling how `parse_array` reads an array-like environment.
///
/// Fields left `None` are copied into the resulting `Array` node as `None`, leaving the
/// choice of default to later stages.
#[derive(Default)]
pub struct ArrayConfig {
    /// Copied to the `Array` node (`array` sets `true`, the matrix family `false`);
    /// presumably padding at the outer edges — confirm in the renderer.
    pub hskip_before_and_after: Option<bool>,
    /// Copied to the `Array` node; set by the align/gather families (presumably extra
    /// `\jot` row spacing — confirm in the renderer).
    pub add_jot: Option<bool>,
    /// Column specifications copied to the `Array` node.
    pub cols: Option<Vec<AlignSpec>>,
    /// Fixed row stretch; when `None`, a user-defined `\arraystretch` macro is parsed,
    /// falling back to `1.0`.
    pub arraystretch: Option<f64>,
    /// Column separation kind copied to the node (e.g. `"align"`, `"gather"`).
    pub col_separation_type: Option<String>,
    /// When `true`, `\end` is the only cell terminator and `\cr` is not redefined.
    pub single_row: bool,
    /// Keep a lone empty row when it is the only row of the environment.
    pub empty_single_row: bool,
    /// Maximum number of cells per row; an `&` that would exceed it is an error.
    pub max_num_cols: Option<usize>,
    /// Copied to the `Array` node.
    pub leqno: Option<bool>,
    /// Number rows automatically using `parser.equation_counter`.
    pub auto_number: bool,
}
65
66
67// ── parseArray ───────────────────────────────────────────────────────────
68
69/// Pull a trailing `\\tag{…}` or `\\nonumber`/`\\notag` off the last cell of a row.
70/// Returns `Auto(true)` when the row is eligible for auto-numbering.
71/// The `auto_number` parameter controls the default when no marker is found.
72fn extract_trailing_tag_from_last_cell(row: &mut [ParseNode], auto_number: bool) -> ParseResult<ArrayTag> {
73    let default_tag = if auto_number { ArrayTag::Auto(true) } else { ArrayTag::Auto(false) };
74    let Some(last) = row.last_mut() else {
75        return Ok(default_tag);
76    };
77
78    let inner: &mut ParseNode = match last {
79        ParseNode::Styling { body, .. } => {
80            if body.len() != 1 {
81                return Ok(default_tag);
82            }
83            &mut body[0]
84        }
85        _ => last,
86    };
87
88    let obody = match inner {
89        ParseNode::OrdGroup { body, .. } => body,
90        _ => return Ok(default_tag),
91    };
92
93    // Look for \\tag
94    let tag_indices: Vec<usize> = obody
95        .iter()
96        .enumerate()
97        .filter(|(_, n)| matches!(n, ParseNode::Tag { .. }))
98        .map(|(i, _)| i)
99        .collect();
100
101    // Look for \\nonumber / \\notag
102    let nonumber_indices: Vec<usize> = obody
103        .iter()
104        .enumerate()
105        .filter(|(_, n)| matches!(n, ParseNode::NoNumber { .. }))
106        .map(|(i, _)| i)
107        .collect();
108
109    // Can't have both \\tag and \\nonumber in the same row
110    if !tag_indices.is_empty() && !nonumber_indices.is_empty() {
111        return Err(ParseError::msg(
112            "Cannot use both \\tag and \\nonumber in the same row",
113        ));
114    }
115
116    // Handle \\tag
117    if !tag_indices.is_empty() {
118        if tag_indices.len() > 1 {
119            return Err(ParseError::msg("Multiple \\tag in a row"));
120        }
121        let idx = tag_indices[0];
122        if idx != obody.len() - 1 {
123            return Err(ParseError::msg(
124                "\\tag must appear at the end of the row after the equation body",
125            ));
126        }
127        match obody.pop() {
128            Some(ParseNode::Tag { tag, .. }) => {
129                if tag.is_empty() {
130                    Ok(ArrayTag::Auto(false))
131                } else {
132                    Ok(ArrayTag::Explicit(tag))
133                }
134            }
135            _ => Ok(default_tag),
136        }
137    } else if !nonumber_indices.is_empty() {
138        // Handle \\nonumber / \\notag
139        if nonumber_indices.len() > 1 {
140            return Err(ParseError::msg("Multiple \\nonumber in a row"));
141        }
142        let idx = nonumber_indices[0];
143        if idx != obody.len() - 1 {
144            return Err(ParseError::msg(
145                "\\nonumber must appear at the end of the row",
146            ));
147        }
148        obody.pop(); // discard the NoNumber node
149        Ok(ArrayTag::Auto(false))
150    } else {
151        // Neither \\tag nor \\nonumber
152        Ok(default_tag)
153    }
154}
155
156fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
157    let mut hline_info = Vec::new();
158    parser.consume_spaces()?;
159
160    let mut nxt = parser.fetch()?.text.clone();
161    if nxt == "\\relax" {
162        parser.consume();
163        parser.consume_spaces()?;
164        nxt = parser.fetch()?.text.clone();
165    }
166    while nxt == "\\hline" || nxt == "\\hdashline" {
167        parser.consume();
168        hline_info.push(nxt == "\\hdashline");
169        parser.consume_spaces()?;
170        nxt = parser.fetch()?.text.clone();
171    }
172    Ok(hline_info)
173}
174
175fn d_cell_style(env_name: &str) -> Option<StyleStr> {
176    if env_name.starts_with('d') {
177        Some(StyleStr::Display)
178    } else {
179        Some(StyleStr::Text)
180    }
181}
182
183pub fn parse_array(
184    parser: &mut Parser,
185    config: ArrayConfig,
186    style: Option<StyleStr>,
187) -> ParseResult<ParseNode> {
188    parser.gullet.begin_group();
189
190    if !config.single_row {
191        parser
192            .gullet
193            .set_text_macro("\\cr", "\\\\\\relax");
194    }
195
196    let arraystretch = config.arraystretch.unwrap_or_else(|| {
197        // Check if \arraystretch is defined as a macro (e.g., via \def\arraystretch{1.5})
198        if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
199            let s = match def {
200                MacroDefinition::Text(s) => s.clone(),
201                MacroDefinition::Tokens { tokens, .. } => {
202                    // Tokens are stored in reverse order (stack convention for expansion)
203                    tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
204                }
205                MacroDefinition::Function(_) => String::new(),
206            };
207            s.parse::<f64>().unwrap_or(1.0)
208        } else {
209            1.0
210        }
211    });
212
213    parser.gullet.begin_group();
214
215    let mut row: Vec<ParseNode> = Vec::new();
216    let mut body: Vec<Vec<ParseNode>> = Vec::new();
217    let mut row_tags: Vec<ArrayTag> = Vec::new();
218    let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
219    let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
220
221    hlines_before_row.push(get_hlines(parser)?);
222
223    loop {
224        let break_token = if config.single_row { "\\end" } else { "\\\\" };
225        let cell_body = parser.parse_expression(false, Some(break_token))?;
226        parser.gullet.end_group();
227        parser.gullet.begin_group();
228
229        let mut cell = ParseNode::OrdGroup {
230            mode: parser.mode,
231            body: cell_body,
232            semisimple: None,
233            loc: None,
234        };
235
236        if let Some(s) = style {
237            cell = ParseNode::Styling {
238                mode: parser.mode,
239                style: s,
240                body: vec![cell],
241                loc: None,
242            };
243        }
244
245        row.push(cell.clone());
246        let next = parser.fetch()?.text.clone();
247
248        if next == "&" {
249            if let Some(max) = config.max_num_cols {
250                if row.len() >= max {
251                    return Err(ParseError::msg("Too many tab characters: &"));
252                }
253            }
254            parser.consume();
255        } else if next == "\\end" {
256            // Check for trailing empty row and remove it
257            let is_empty_trailing = if let Some(s) = style {
258                if s == StyleStr::Text || s == StyleStr::Display {
259                    if let ParseNode::Styling { body: ref sb, .. } = cell {
260                        if let Some(ParseNode::OrdGroup {
261                            body: ref ob, ..
262                        }) = sb.first()
263                        {
264                            ob.is_empty()
265                        } else {
266                            false
267                        }
268                    } else {
269                        false
270                    }
271                } else {
272                    false
273                }
274            } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
275                ob.is_empty()
276            } else {
277                false
278            };
279
280            let row_tag = extract_trailing_tag_from_last_cell(&mut row, config.auto_number)?;
281            row_tags.push(row_tag);
282            body.push(row);
283
284            if is_empty_trailing
285                && (body.len() > 1 || !config.empty_single_row)
286            {
287                body.pop();
288                row_tags.pop();
289            }
290
291            if hlines_before_row.len() < body.len() + 1 {
292                hlines_before_row.push(vec![]);
293            }
294            break;
295        } else if next == "\\\\" {
296            parser.consume();
297            let size = if parser.gullet.future().text != " " {
298                parser.parse_size_group(true)?
299            } else {
300                None
301            };
302            let gap = size.and_then(|s| {
303                if let ParseNode::Size { value, .. } = s {
304                    Some(value)
305                } else {
306                    None
307                }
308            });
309            row_gaps.push(gap);
310
311            let row_tag = extract_trailing_tag_from_last_cell(&mut row, config.auto_number)?;
312            row_tags.push(row_tag);
313            body.push(row);
314            hlines_before_row.push(get_hlines(parser)?);
315            row = Vec::new();
316        } else {
317            return Err(ParseError::msg(format!(
318                "Expected & or \\\\ or \\cr or \\end, got '{}'",
319                next
320            )));
321        }
322    }
323
324    parser.gullet.end_group();
325    parser.gullet.end_group();
326
327    // Post-process row tags for auto-numbering
328    let tags = if config.auto_number {
329        let mut processed: Vec<ArrayTag> = Vec::with_capacity(row_tags.len());
330        let mut any_visible = false;
331        for raw_tag in &row_tags {
332            match raw_tag {
333                ArrayTag::Explicit(nodes) if !nodes.is_empty() => {
334                    // Explicit \\tag{...}: step counter, keep tag content as-is
335                    parser.equation_counter += 1;
336                    processed.push(ArrayTag::Explicit(nodes.clone()));
337                    any_visible = true;
338                }
339                ArrayTag::Explicit(_) => {
340                    // Empty \\tag{}: treat as suppressed
341                    processed.push(ArrayTag::Auto(false));
342                }
343                ArrayTag::Auto(true) => {
344                    // Auto-number this row: step counter, generate "(N)"
345                    parser.equation_counter += 1;
346                    let num_str = parser.equation_counter.to_string();
347                    let tag_nodes = vec![
348                        ParseNode::MathOrd {
349                            mode: Mode::Math,
350                            text: "(".to_string(),
351                            loc: None,
352                        },
353                        ParseNode::MathOrd {
354                            mode: Mode::Math,
355                            text: num_str,
356                            loc: None,
357                        },
358                        ParseNode::MathOrd {
359                            mode: Mode::Math,
360                            text: ")".to_string(),
361                            loc: None,
362                        },
363                    ];
364                    processed.push(ArrayTag::Explicit(tag_nodes));
365                    any_visible = true;
366                }
367                ArrayTag::Auto(false) => {
368                    // Suppressed by \\nonumber or empty \\tag{}: no counter step, no tag
369                    processed.push(ArrayTag::Auto(false));
370                }
371            }
372        }
373        if any_visible { Some(processed) } else { None }
374    } else {
375        // Not an auto-numbering environment: keep original behavior
376        if row_tags.iter().any(|t| {
377            matches!(t, ArrayTag::Explicit(nodes) if !nodes.is_empty())
378        }) {
379            Some(row_tags)
380        } else {
381            None
382        }
383    };
384
385    Ok(ParseNode::Array {
386        mode: parser.mode,
387        body,
388        row_gaps,
389        hlines_before_row,
390        cols: config.cols,
391        col_separation_type: config.col_separation_type,
392        hskip_before_and_after: config.hskip_before_and_after,
393        add_jot: config.add_jot,
394        arraystretch,
395        tags,
396        leqno: config.leqno,
397        is_cd: None,
398        loc: None,
399    })
400}
401
402// ── array / darray ───────────────────────────────────────────────────────
403
404fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
405    fn handle_array(
406        ctx: &mut EnvContext,
407        args: Vec<ParseNode>,
408        _opt_args: Vec<Option<ParseNode>>,
409    ) -> ParseResult<ParseNode> {
410        let colalign = match &args[0] {
411            ParseNode::OrdGroup { body, .. } => body.clone(),
412            other if other.is_symbol_node() => vec![other.clone()],
413            _ => return Err(ParseError::msg("Invalid column alignment for array")),
414        };
415
416        let mut cols = Vec::new();
417        for nde in &colalign {
418            let ca = nde
419                .symbol_text()
420                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
421            match ca {
422                "l" | "c" | "r" => cols.push(AlignSpec {
423                    align_type: AlignType::Align,
424                    align: Some(ca.to_string()),
425                    pregap: None,
426                    postgap: None,
427                }),
428                "|" => cols.push(AlignSpec {
429                    align_type: AlignType::Separator,
430                    align: Some("|".to_string()),
431                    pregap: None,
432                    postgap: None,
433                }),
434                ":" => cols.push(AlignSpec {
435                    align_type: AlignType::Separator,
436                    align: Some(":".to_string()),
437                    pregap: None,
438                    postgap: None,
439                }),
440                _ => {
441                    return Err(ParseError::msg(format!(
442                        "Unknown column alignment: {}",
443                        ca
444                    )))
445                }
446            }
447        }
448
449        let max_num_cols = cols.len();
450        let config = ArrayConfig {
451            cols: Some(cols),
452            hskip_before_and_after: Some(true),
453            max_num_cols: Some(max_num_cols),
454            ..Default::default()
455        };
456        parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
457    }
458
459    for name in &["array", "darray"] {
460        map.insert(
461            name,
462            EnvSpec {
463                num_args: 1,
464                num_optional_args: 0,
465                handler: handle_array,
466            },
467        );
468    }
469}
470
471// ── matrix variants ──────────────────────────────────────────────────────
472
473fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
474    fn handle_matrix(
475        ctx: &mut EnvContext,
476        _args: Vec<ParseNode>,
477        _opt_args: Vec<Option<ParseNode>>,
478    ) -> ParseResult<ParseNode> {
479        let base_name = ctx.env_name.replace('*', "");
480        let delimiters: Option<(&str, &str)> = match base_name.as_str() {
481            "matrix" => None,
482            "pmatrix" => Some(("(", ")")),
483            "bmatrix" => Some(("[", "]")),
484            "Bmatrix" => Some(("\\{", "\\}")),
485            "vmatrix" => Some(("|", "|")),
486            "Vmatrix" => Some(("\\Vert", "\\Vert")),
487            _ => None,
488        };
489
490        let mut col_align = "c".to_string();
491
492        // mathtools starred matrix: parse optional [l|c|r] alignment
493        if ctx.env_name.ends_with('*') {
494            ctx.parser.gullet.consume_spaces();
495            if ctx.parser.gullet.future().text == "[" {
496                ctx.parser.gullet.pop_token();
497                ctx.parser.gullet.consume_spaces();
498                let align_tok = ctx.parser.gullet.pop_token();
499                if !"lcr".contains(align_tok.text.as_str()) {
500                    return Err(ParseError::new(
501                        "Expected l or c or r".to_string(),
502                        Some(&align_tok),
503                    ));
504                }
505                col_align = align_tok.text.clone();
506                ctx.parser.gullet.consume_spaces();
507                let close = ctx.parser.gullet.pop_token();
508                if close.text != "]" {
509                    return Err(ParseError::new(
510                        "Expected ]".to_string(),
511                        Some(&close),
512                    ));
513                }
514            }
515        }
516
517        let config = ArrayConfig {
518            hskip_before_and_after: Some(false),
519            cols: Some(vec![AlignSpec {
520                align_type: AlignType::Align,
521                align: Some(col_align.clone()),
522                pregap: None,
523                postgap: None,
524            }]),
525            ..Default::default()
526        };
527
528        let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
529
530        // Fix cols to match actual number of columns
531        if let ParseNode::Array {
532            ref body,
533            ref mut cols,
534            ..
535        } = res
536        {
537            let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
538            *cols = Some(
539                (0..num_cols)
540                    .map(|_| AlignSpec {
541                        align_type: AlignType::Align,
542                        align: Some(col_align.to_string()),
543                        pregap: None,
544                        postgap: None,
545                    })
546                    .collect(),
547            );
548        }
549
550        match delimiters {
551            Some((left, right)) => Ok(ParseNode::LeftRight {
552                mode: ctx.mode,
553                body: vec![res],
554                left: left.to_string(),
555                right: right.to_string(),
556                right_color: None,
557                loc: None,
558            }),
559            None => Ok(res),
560        }
561    }
562
563    for name in &[
564        "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
565        "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
566    ] {
567        map.insert(
568            name,
569            EnvSpec {
570                num_args: 0,
571                num_optional_args: 0,
572                handler: handle_matrix,
573            },
574        );
575    }
576}
577
578// ── cases / dcases / rcases / drcases ────────────────────────────────────
579
580fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
581    fn handle_cases(
582        ctx: &mut EnvContext,
583        _args: Vec<ParseNode>,
584        _opt_args: Vec<Option<ParseNode>>,
585    ) -> ParseResult<ParseNode> {
586        let config = ArrayConfig {
587            arraystretch: Some(1.2),
588            cols: Some(vec![
589                AlignSpec {
590                    align_type: AlignType::Align,
591                    align: Some("l".to_string()),
592                    pregap: Some(0.0),
593                    postgap: Some(1.0),
594                },
595                AlignSpec {
596                    align_type: AlignType::Align,
597                    align: Some("l".to_string()),
598                    pregap: Some(0.0),
599                    postgap: Some(0.0),
600                },
601            ]),
602            ..Default::default()
603        };
604
605        let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
606
607        let (left, right) = if ctx.env_name.contains('r') {
608            (".", "\\}")
609        } else {
610            ("\\{", ".")
611        };
612
613        Ok(ParseNode::LeftRight {
614            mode: ctx.mode,
615            body: vec![res],
616            left: left.to_string(),
617            right: right.to_string(),
618            right_color: None,
619            loc: None,
620        })
621    }
622
623    for name in &["cases", "dcases", "rcases", "drcases"] {
624        map.insert(
625            name,
626            EnvSpec {
627                num_args: 0,
628                num_optional_args: 0,
629                handler: handle_cases,
630            },
631        );
632    }
633}
634
635// ── align / align* / aligned / split / alignat / alignat* / alignedat ────
636
637fn handle_aligned(
638    ctx: &mut EnvContext,
639    args: Vec<ParseNode>,
640    _opt_args: Vec<Option<ParseNode>>,
641) -> ParseResult<ParseNode> {
642        let is_split = ctx.env_name == "split";
643        let is_alignat = ctx.env_name.contains("at");
644        let sep_type = if is_alignat { "alignat" } else { "align" };
645        let auto_number = !ctx.env_name.ends_with('*')
646            && !is_split
647            && ctx.env_name != "aligned"
648            && ctx.env_name != "alignedat";
649
650        let config = ArrayConfig {
651            add_jot: Some(true),
652            empty_single_row: true,
653            col_separation_type: Some(sep_type.to_string()),
654            max_num_cols: if is_split { Some(2) } else { None },
655            auto_number,
656            ..Default::default()
657        };
658
659        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
660
661        // Extract explicit column count from first arg (alignat only)
662        let mut num_maths = 0usize;
663        let mut explicit_cols = 0usize;
664        if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
665            let mut arg_str = String::new();
666            for node in body {
667                if let Some(t) = node.symbol_text() {
668                    arg_str.push_str(t);
669                }
670            }
671            if let Ok(n) = arg_str.parse::<usize>() {
672                num_maths = n;
673                explicit_cols = n * 2;
674            }
675        }
676        let is_aligned = explicit_cols == 0;
677
678        // Determine actual number of columns
679        let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
680            body.iter().map(|r| r.len()).max().unwrap_or(0)
681        } else {
682            0
683        };
684
685        if let ParseNode::Array {
686            body: ref mut array_body,
687            ..
688        } = res
689        {
690            for row in array_body.iter_mut() {
691                // Prepend empty group at every even-indexed cell (2nd, 4th, ...)
692                let mut i = 1;
693                while i < row.len() {
694                    if let ParseNode::Styling {
695                        body: ref mut styling_body,
696                        ..
697                    } = row[i]
698                    {
699                        if let Some(ParseNode::OrdGroup {
700                            body: ref mut og_body,
701                            ..
702                        }) = styling_body.first_mut()
703                        {
704                            og_body.insert(
705                                0,
706                                ParseNode::OrdGroup {
707                                    mode: ctx.mode,
708                                    body: vec![],
709                                    semisimple: None,
710                                    loc: None,
711                                },
712                            );
713                        }
714                    }
715                    i += 2;
716                }
717
718                if !is_aligned {
719                    let cur_maths = row.len() / 2;
720                    if num_maths < cur_maths {
721                        return Err(ParseError::msg(format!(
722                            "Too many math in a row: expected {}, but got {}",
723                            num_maths, cur_maths
724                        )));
725                    }
726                } else if num_cols < row.len() {
727                    num_cols = row.len();
728                }
729            }
730        }
731
732        if !is_aligned {
733            num_cols = explicit_cols;
734        }
735
736        let mut cols = Vec::new();
737        for i in 0..num_cols {
738            let (align, pregap) = if i % 2 == 1 {
739                ("l", 0.0)
740            } else if i > 0 && is_aligned {
741                ("r", 1.0)
742            } else {
743                ("r", 0.0)
744            };
745            cols.push(AlignSpec {
746                align_type: AlignType::Align,
747                align: Some(align.to_string()),
748                pregap: Some(pregap),
749                postgap: Some(0.0),
750            });
751        }
752
753        if let ParseNode::Array {
754            cols: ref mut array_cols,
755            col_separation_type: ref mut array_sep_type,
756            ..
757        } = res
758        {
759            *array_cols = Some(cols);
760            *array_sep_type = Some(
761                if is_aligned { "align" } else { "alignat" }.to_string(),
762            );
763        }
764
765    Ok(res)
766}
767
768fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
769    for name in &["align", "align*", "aligned", "split"] {
770        map.insert(
771            name,
772            EnvSpec {
773                num_args: 0,
774                num_optional_args: 0,
775                handler: handle_aligned,
776            },
777        );
778    }
779}
780
781// ── gathered / gather / gather* ──────────────────────────────────────────
782
783fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
784    fn handle_gathered(
785        ctx: &mut EnvContext,
786        _args: Vec<ParseNode>,
787        _opt_args: Vec<Option<ParseNode>>,
788    ) -> ParseResult<ParseNode> {
789        let auto_number = !ctx.env_name.ends_with('*') && ctx.env_name != "gathered";
790        let config = ArrayConfig {
791            cols: Some(vec![AlignSpec {
792                align_type: AlignType::Align,
793                align: Some("c".to_string()),
794                pregap: None,
795                postgap: None,
796            }]),
797            add_jot: Some(true),
798            col_separation_type: Some("gather".to_string()),
799            empty_single_row: true,
800            auto_number,
801            ..Default::default()
802        };
803        parse_array(ctx.parser, config, Some(StyleStr::Display))
804    }
805
806    for name in &["gathered", "gather", "gather*"] {
807        map.insert(
808            name,
809            EnvSpec {
810                num_args: 0,
811                num_optional_args: 0,
812                handler: handle_gathered,
813            },
814        );
815    }
816}
817
818// ── equation / equation* ─────────────────────────────────────────────────
819
820fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
821    fn handle_equation(
822        ctx: &mut EnvContext,
823        _args: Vec<ParseNode>,
824        _opt_args: Vec<Option<ParseNode>>,
825    ) -> ParseResult<ParseNode> {
826        let auto_number = !ctx.env_name.ends_with('*');
827        let config = ArrayConfig {
828            empty_single_row: true,
829            single_row: true,
830            max_num_cols: Some(1),
831            auto_number,
832            ..Default::default()
833        };
834        parse_array(ctx.parser, config, Some(StyleStr::Display))
835    }
836
837    for name in &["equation", "equation*"] {
838        map.insert(
839            name,
840            EnvSpec {
841                num_args: 0,
842                num_optional_args: 0,
843                handler: handle_equation,
844            },
845        );
846    }
847}
848
849// ── smallmatrix ──────────────────────────────────────────────────────────
850
851fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
852    fn handle_smallmatrix(
853        ctx: &mut EnvContext,
854        _args: Vec<ParseNode>,
855        _opt_args: Vec<Option<ParseNode>>,
856    ) -> ParseResult<ParseNode> {
857        let config = ArrayConfig {
858            arraystretch: Some(0.5),
859            ..Default::default()
860        };
861        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
862        if let ParseNode::Array {
863            ref mut col_separation_type,
864            ..
865        } = res
866        {
867            *col_separation_type = Some("small".to_string());
868        }
869        Ok(res)
870    }
871
872    map.insert(
873        "smallmatrix",
874        EnvSpec {
875            num_args: 0,
876            num_optional_args: 0,
877            handler: handle_smallmatrix,
878        },
879    );
880}
881
882// ── alignat / alignat* / alignedat ──────────────────────────────────────
883
884fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
885    for name in &["alignat", "alignat*", "alignedat"] {
886        map.insert(
887            name,
888            EnvSpec {
889                num_args: 1,
890                num_optional_args: 0,
891                handler: handle_aligned,
892            },
893        );
894    }
895}
896
897// ── CD (amscd commutative diagrams) ──────────────────────────────────────
898
899fn register_cd(map: &mut HashMap<&'static str, EnvSpec>) {
900    fn handle_cd(
901        ctx: &mut EnvContext,
902        _args: Vec<ParseNode>,
903        _opt_args: Vec<Option<ParseNode>>,
904    ) -> ParseResult<ParseNode> {
905        // Collect all raw tokens until \end
906        let mut raw: Vec<Token> = Vec::new();
907        loop {
908            let tok = ctx.parser.gullet.future().clone();
909            if tok.text == "\\end" || tok.text == "EOF" {
910                break;
911            }
912            ctx.parser.gullet.pop_token();
913            raw.push(tok);
914        }
915
916        // Split into rows at \\ or \cr
917        let rows = cd_split_rows(raw);
918
919        let mut body: Vec<Vec<ParseNode>> = Vec::new();
920        let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
921        let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
922        hlines_before_row.push(vec![]);
923
924        for row_toks in rows {
925            // Skip purely-whitespace rows
926            if row_toks.iter().all(|t| t.text == " ") {
927                continue;
928            }
929            let cells = cd_parse_row(ctx.parser, row_toks)?;
930            if !cells.is_empty() {
931                body.push(cells);
932                row_gaps.push(None);
933                hlines_before_row.push(vec![]);
934            }
935        }
936
937        if body.is_empty() {
938            body.push(vec![]);
939            hlines_before_row.push(vec![]);
940        }
941
942        Ok(ParseNode::Array {
943            mode: ctx.mode,
944            body,
945            row_gaps,
946            hlines_before_row,
947            cols: None,
948            col_separation_type: Some("CD".to_string()),
949            hskip_before_and_after: Some(false),
950            add_jot: None,
951            arraystretch: 1.0,
952            tags: None,
953            leqno: None,
954            is_cd: Some(true),
955            loc: None,
956        })
957    }
958
959    map.insert(
960        "CD",
961        EnvSpec {
962            num_args: 0,
963            num_optional_args: 0,
964            handler: handle_cd,
965        },
966    );
967}
968
969/// Split a flat token list into rows at `\\` or `\cr` boundaries.
970fn cd_split_rows(tokens: Vec<Token>) -> Vec<Vec<Token>> {
971    let mut rows: Vec<Vec<Token>> = Vec::new();
972    let mut current: Vec<Token> = Vec::new();
973    for tok in tokens {
974        if tok.text == "\\\\" || tok.text == "\\cr" {
975            rows.push(current);
976            current = Vec::new();
977        } else {
978            current.push(tok);
979        }
980    }
981    if !current.is_empty() {
982        rows.push(current);
983    }
984    rows
985}
986
987/// Collect tokens from `tokens[start..]` up to (but not including) the first
988/// token whose text equals `delimiter`.  Returns (collected_tokens, tokens_consumed).
989/// `tokens_consumed` includes the delimiter itself if found.
990fn cd_collect_until(tokens: &[Token], start: usize, delimiter: &str) -> (Vec<Token>, usize) {
991    let mut result = Vec::new();
992    let mut i = start;
993    while i < tokens.len() {
994        if tokens[i].text == delimiter {
995            i += 1; // consume the delimiter
996            break;
997        }
998        result.push(tokens[i].clone());
999        i += 1;
1000    }
1001    (result, i - start)
1002}
1003
1004/// Collect tokens from `tokens[start..]` up to (but not including) the next `@`.
1005fn cd_collect_until_at(tokens: &[Token], start: usize) -> (Vec<Token>, usize) {
1006    let mut result = Vec::new();
1007    let mut i = start;
1008    while i < tokens.len() && tokens[i].text != "@" {
1009        result.push(tokens[i].clone());
1010        i += 1;
1011    }
1012    (result, i - start)
1013}
1014
1015/// Use the parser to parse a token slice as a math OrdGroup.
1016/// Tokens must be in forward order; this function reverses them internally for subparse().
1017fn cd_parse_tokens(parser: &mut Parser, tokens: Vec<Token>) -> ParseResult<ParseNode> {
1018    // Filter pure whitespace
1019    let has_content = tokens.iter().any(|t| t.text != " ");
1020    if !has_content {
1021        return Ok(ParseNode::OrdGroup {
1022            mode: parser.mode,
1023            body: vec![],
1024            semisimple: None,
1025            loc: None,
1026        });
1027    }
1028    // subparse() expects tokens in reverse order (stack convention)
1029    let mut rev = tokens;
1030    rev.reverse();
1031    let body = parser.subparse(rev)?;
1032    Ok(ParseNode::OrdGroup {
1033        mode: parser.mode,
1034        body,
1035        semisimple: None,
1036        loc: None,
1037    })
1038}
1039
1040/// Parse one row of a CD environment from its raw token list.
1041/// Returns the list of ParseNode cells for the grid row.
1042fn cd_parse_row(parser: &mut Parser, row_tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
1043    let toks = &row_tokens;
1044    let n = toks.len();
1045    let mut cells: Vec<ParseNode> = Vec::new();
1046    let mut i = 0usize;
1047
1048    while i < n {
1049        // Skip spaces at start of each cell
1050        while i < n && toks[i].text == " " {
1051            i += 1;
1052        }
1053        if i >= n {
1054            break;
1055        }
1056
1057        if toks[i].text == "@" {
1058            i += 1; // consume `@`
1059            if i >= n {
1060                return Err(ParseError::msg("Unexpected end of CD row after @"));
1061            }
1062            let dir = toks[i].text.clone();
1063            i += 1; // consume direction char
1064
1065            let mode = parser.mode;
1066            let arrow = match dir.as_str() {
1067                ">" | "<" => {
1068                    let (above_toks, c1) = cd_collect_until(toks, i, &dir);
1069                    i += c1;
1070                    let (below_toks, c2) = cd_collect_until(toks, i, &dir);
1071                    i += c2;
1072                    let label_above = cd_parse_tokens(parser, above_toks)?;
1073                    let label_below = cd_parse_tokens(parser, below_toks)?;
1074                    ParseNode::CdArrow {
1075                        mode,
1076                        direction: if dir == ">" { "right" } else { "left" }.to_string(),
1077                        label_above: Some(Box::new(label_above)),
1078                        label_below: Some(Box::new(label_below)),
1079                        loc: None,
1080                    }
1081                }
1082                "V" | "A" => {
1083                    let (left_toks, c1) = cd_collect_until(toks, i, &dir);
1084                    i += c1;
1085                    let (right_toks, c2) = cd_collect_until(toks, i, &dir);
1086                    i += c2;
1087                    let label_above = cd_parse_tokens(parser, left_toks)?;
1088                    let label_below = cd_parse_tokens(parser, right_toks)?;
1089                    ParseNode::CdArrow {
1090                        mode,
1091                        direction: if dir == "V" { "down" } else { "up" }.to_string(),
1092                        label_above: Some(Box::new(label_above)),
1093                        label_below: Some(Box::new(label_below)),
1094                        loc: None,
1095                    }
1096                }
1097                "=" => ParseNode::CdArrow {
1098                    mode,
1099                    direction: "horiz_eq".to_string(),
1100                    label_above: None,
1101                    label_below: None,
1102                    loc: None,
1103                },
1104                "|" => ParseNode::CdArrow {
1105                    mode,
1106                    direction: "vert_eq".to_string(),
1107                    label_above: None,
1108                    label_below: None,
1109                    loc: None,
1110                },
1111                "." => ParseNode::CdArrow {
1112                    mode,
1113                    direction: "none".to_string(),
1114                    label_above: None,
1115                    label_below: None,
1116                    loc: None,
1117                },
1118                _ => return Err(ParseError::msg(format!("Unknown CD directive: @{}", dir))),
1119            };
1120            cells.push(arrow);
1121        } else {
1122            // Object cell: collect until next `@`
1123            let (obj_toks, consumed) = cd_collect_until_at(toks, i);
1124            i += consumed;
1125            let obj = cd_parse_tokens(parser, obj_toks)?;
1126            cells.push(obj);
1127        }
1128    }
1129
1130    // Post-process: structure cells into the (2n-1) grid pattern.
1131    Ok(cd_structure_row(cells, parser.mode))
1132}
1133
1134/// Given the raw parsed cells of one CD row, produce the correctly-structured grid row.
1135///
1136/// Object rows already alternate: obj, h-arrow, obj, h-arrow, …, obj.
1137/// Arrow rows contain only CdArrow nodes (plus whitespace OrdGroups which we strip),
1138/// and need empty OrdGroup fillers inserted between consecutive arrows.
1139fn cd_structure_row(cells: Vec<ParseNode>, mode: Mode) -> Vec<ParseNode> {
1140    // Detect arrow row: all cells are either CdArrow or empty OrdGroup
1141    let is_arrow_row = cells.iter().all(|c| match c {
1142        ParseNode::CdArrow { .. } => true,
1143        ParseNode::OrdGroup { body, .. } => body.is_empty(),
1144        _ => false,
1145    }) && cells.iter().any(|c| matches!(c, ParseNode::CdArrow { .. }));
1146
1147    if is_arrow_row {
1148        let arrows: Vec<ParseNode> = cells
1149            .into_iter()
1150            .filter(|c| matches!(c, ParseNode::CdArrow { .. }))
1151            .collect();
1152
1153        if arrows.is_empty() {
1154            return vec![];
1155        }
1156
1157        let empty = || ParseNode::OrdGroup {
1158            mode,
1159            body: vec![],
1160            semisimple: None,
1161            loc: None,
1162        };
1163
1164        let mut result = Vec::with_capacity(arrows.len() * 2 - 1);
1165        for (idx, arrow) in arrows.into_iter().enumerate() {
1166            if idx > 0 {
1167                result.push(empty());
1168            }
1169            result.push(arrow);
1170        }
1171        result
1172    } else {
1173        // Object row: already in correct format
1174        cells
1175    }
1176}
1177
1178// ── subarray ────────────────────────────────────────────────────────────
1179
1180fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
1181    fn handle_subarray(
1182        ctx: &mut EnvContext,
1183        args: Vec<ParseNode>,
1184        _opt_args: Vec<Option<ParseNode>>,
1185    ) -> ParseResult<ParseNode> {
1186        let colalign = match &args[0] {
1187            ParseNode::OrdGroup { body, .. } => body.clone(),
1188            other if other.is_symbol_node() => vec![other.clone()],
1189            _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
1190        };
1191
1192        let mut cols = Vec::new();
1193        for nde in &colalign {
1194            let ca = nde
1195                .symbol_text()
1196                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
1197            match ca {
1198                "l" | "c" => cols.push(AlignSpec {
1199                    align_type: AlignType::Align,
1200                    align: Some(ca.to_string()),
1201                    pregap: None,
1202                    postgap: None,
1203                }),
1204                _ => {
1205                    return Err(ParseError::msg(format!(
1206                        "Unknown column alignment: {}",
1207                        ca
1208                    )))
1209                }
1210            }
1211        }
1212
1213        if cols.len() > 1 {
1214            return Err(ParseError::msg("{subarray} can contain only one column"));
1215        }
1216
1217        let config = ArrayConfig {
1218            cols: Some(cols),
1219            hskip_before_and_after: Some(false),
1220            arraystretch: Some(0.5),
1221            ..Default::default()
1222        };
1223
1224        let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
1225
1226        if let ParseNode::Array { ref body, .. } = res {
1227            if !body.is_empty() && body[0].len() > 1 {
1228                return Err(ParseError::msg("{subarray} can contain only one column"));
1229            }
1230        }
1231
1232        Ok(res)
1233    }
1234
1235    map.insert(
1236        "subarray",
1237        EnvSpec {
1238            num_args: 1,
1239            num_optional_args: 0,
1240            handler: handle_subarray,
1241        },
1242    );
1243}
1244
1245// ── prooftree (bussproofs subset) ───────────────────────────────────────
1246//
// Implemented: \AxiomC, \UnaryInfC, \BinaryInfC, \TrinaryInfC, \QuaternaryInfC,
//   \QuinaryInfC (and their abbreviations \AXC, \UIC, \BIC, \TIC; the non-`C`
//   spellings \Axiom, \UnaryInf, … are accepted as aliases that also take a
//   braced argument), \LeftLabel/\RightLabel (and \LL, \RL), \fCenter,
//   line styles: \solidLine (alias \singleLine), \dashedLine, \noLine
//   (and \always* variants, including \alwaysSingleLine),
//   \rootAtTop, \rootAtBottom (and \always* variants).
1252//
1253// Not yet implemented: \InsertBetweenHyps, \ScoreTree, \Cell, \noCell.
1254
1255fn register_prooftree(map: &mut HashMap<&'static str, EnvSpec>) {
1256    fn handle_prooftree(
1257        ctx: &mut EnvContext,
1258        _args: Vec<ParseNode>,
1259        _opt_args: Vec<Option<ParseNode>>,
1260    ) -> ParseResult<ParseNode> {
1261        parse_prooftree(ctx.parser)
1262    }
1263
1264    map.insert(
1265        "prooftree",
1266        EnvSpec {
1267            num_args: 0,
1268            num_optional_args: 0,
1269            handler: handle_prooftree,
1270        },
1271    );
1272}
1273
/// Number of premises an inference command pops from the proof stack.
///
/// Returns `None` when `name` is not a recognised inference command. Both
/// the `…InfC` and `…Inf` spellings are accepted, plus the `\UIC`, `\BIC`
/// and `\TIC` abbreviations.
fn proof_command_arity(name: &str) -> Option<usize> {
    let premises = match name {
        "\\UnaryInfC" | "\\UnaryInf" | "\\UIC" => 1,
        "\\BinaryInfC" | "\\BinaryInf" | "\\BIC" => 2,
        "\\TrinaryInfC" | "\\TrinaryInf" | "\\TIC" => 3,
        "\\QuaternaryInfC" | "\\QuaternaryInf" => 4,
        "\\QuinaryInfC" | "\\QuinaryInf" => 5,
        _ => return None,
    };
    Some(premises)
}
1284
1285fn parse_prooftree_arg(parser: &mut Parser, command: &str) -> ParseResult<Vec<ParseNode>> {
1286    let arg = parser.parse_argument_group(false, None)?.ok_or_else(|| {
1287        ParseError::msg(format!("Expected argument for {}", command))
1288    })?;
1289    Ok(ParseNode::ord_argument(arg))
1290}
1291
1292fn parse_prooftree(parser: &mut Parser) -> ParseResult<ParseNode> {
1293    let mut stack: Vec<ProofBranch> = Vec::new();
1294    let mut left_label: Option<Vec<ParseNode>> = None;
1295    let mut right_label: Option<Vec<ParseNode>> = None;
1296    let mut next_line_style = ProofLineStyle::Solid;
1297    let mut default_line_style = ProofLineStyle::Solid;
1298    let mut next_root_at_top = false;
1299    let mut default_root_at_top = false;
1300
1301    loop {
1302        parser.consume_spaces()?;
1303        let token = parser.fetch()?;
1304        let command = token.text.clone();
1305
1306        if command == "\\end" {
1307            break;
1308        }
1309        parser.consume();
1310
1311        match command.as_str() {
1312            "\\AxiomC" | "\\Axiom" | "\\AXC" => {
1313                let conclusion = parse_prooftree_arg(parser, &command)?;
1314                stack.push(ProofBranch {
1315                    conclusion,
1316                    premises: Vec::new(),
1317                    left_label: None,
1318                    right_label: None,
1319                    line_style: ProofLineStyle::None,
1320                    root_at_top: false,
1321                });
1322            }
1323            "\\LeftLabel" | "\\LL" => {
1324                left_label = Some(parse_prooftree_arg(parser, &command)?);
1325            }
1326            "\\RightLabel" | "\\RL" => {
1327                right_label = Some(parse_prooftree_arg(parser, &command)?);
1328            }
1329            "\\singleLine" | "\\solidLine" => {
1330                next_line_style = ProofLineStyle::Solid;
1331            }
1332            "\\dashedLine" => {
1333                next_line_style = ProofLineStyle::Dashed;
1334            }
1335            "\\noLine" => {
1336                next_line_style = ProofLineStyle::None;
1337            }
1338            "\\alwaysSingleLine" | "\\alwaysSolidLine" => {
1339                default_line_style = ProofLineStyle::Solid;
1340                next_line_style = ProofLineStyle::Solid;
1341            }
1342            "\\alwaysDashedLine" => {
1343                default_line_style = ProofLineStyle::Dashed;
1344                next_line_style = ProofLineStyle::Dashed;
1345            }
1346            "\\alwaysNoLine" => {
1347                default_line_style = ProofLineStyle::None;
1348                next_line_style = ProofLineStyle::None;
1349            }
1350            "\\rootAtTop" => {
1351                next_root_at_top = true;
1352            }
1353            "\\rootAtBottom" => {
1354                next_root_at_top = false;
1355            }
1356            "\\alwaysRootAtTop" => {
1357                default_root_at_top = true;
1358                next_root_at_top = true;
1359            }
1360            "\\alwaysRootAtBottom" => {
1361                default_root_at_top = false;
1362                next_root_at_top = false;
1363            }
1364            name if proof_command_arity(name).is_some() => {
1365                let arity = proof_command_arity(name).unwrap();
1366                if stack.len() < arity {
1367                    return Err(ParseError::msg(format!(
1368                        "{} needs {} premise(s), but only {} available",
1369                        name,
1370                        arity,
1371                        stack.len()
1372                    )));
1373                }
1374                let conclusion = parse_prooftree_arg(parser, name)?;
1375                let start = stack.len() - arity;
1376                let premises = stack.split_off(start);
1377                stack.push(ProofBranch {
1378                    conclusion,
1379                    premises,
1380                    left_label: left_label.take(),
1381                    right_label: right_label.take(),
1382                    line_style: next_line_style.clone(),
1383                    root_at_top: next_root_at_top,
1384                });
1385                next_line_style = default_line_style.clone();
1386                next_root_at_top = default_root_at_top;
1387            }
1388            _ => {
1389                return Err(ParseError::msg(format!(
1390                    "{} valid only as a supported bussproofs command within prooftree",
1391                    command
1392                )));
1393            }
1394        }
1395    }
1396
1397    if stack.len() != 1 {
1398        return Err(ParseError::msg(format!(
1399            "prooftree ended with {} proof stack item(s), expected 1",
1400            stack.len()
1401        )));
1402    }
1403
1404    if left_label.is_some() || right_label.is_some() {
1405        return Err(ParseError::msg(
1406            "prooftree has a \\LeftLabel or \\RightLabel without a following inference command",
1407        ));
1408    }
1409
1410    Ok(ParseNode::ProofTree {
1411        mode: parser.mode,
1412        tree: stack.pop().unwrap(),
1413        loc: None,
1414    })
1415}