// ratex_parser/environments.rs

1use std::collections::HashMap;
2
3use ratex_lexer::token::Token;
4
5use crate::error::{ParseError, ParseResult};
6use crate::macro_expander::MacroDefinition;
7use crate::parse_node::{AlignSpec, AlignType, ArrayTag, Measurement, Mode, ParseNode, StyleStr};
8use crate::parser::Parser;
9
10// ── Environment registry ─────────────────────────────────────────────────
11
/// State handed to an environment handler (see [`EnvHandler`]).
pub struct EnvContext<'a, 'b> {
    /// Mode recorded for the environment; handlers in this file use it as the
    /// mode of wrapper nodes they build (e.g. the `LeftRight` around a
    /// `pmatrix`). Presumably the mode where `\begin` appeared — confirm at caller.
    pub mode: Mode,
    /// Environment name exactly as registered in [`ENVIRONMENTS`]
    /// (e.g. `"pmatrix*"`, `"drcases"`); shared handlers branch on it.
    pub env_name: String,
    /// Parser from which the handler reads the environment body.
    pub parser: &'a mut Parser<'b>,
}
17
/// Signature shared by every environment handler.
///
/// `args` carries the environment's mandatory arguments and `opt_args` its
/// optional ones — presumably `EnvSpec::num_args` / `num_optional_args` of
/// them, parsed by the caller (not visible here). The handler reads the body
/// through `ctx.parser` and returns the node representing the environment.
pub type EnvHandler = fn(
    ctx: &mut EnvContext,
    args: Vec<ParseNode>,
    opt_args: Vec<Option<ParseNode>>,
) -> ParseResult<ParseNode>;
23
/// Registration record for one environment name in [`ENVIRONMENTS`].
pub struct EnvSpec {
    /// Number of mandatory arguments (e.g. 1 for `array`'s column spec and
    /// `alignat`'s column-pair count, 0 for the matrix family).
    pub num_args: usize,
    /// Number of optional arguments. Every environment registered in the
    /// visible part of this file uses 0; the starred matrices read their
    /// `[l|c|r]` option inside the handler instead.
    pub num_optional_args: usize,
    /// Function that parses the environment body and builds its node.
    pub handler: EnvHandler,
}
29
30pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
31    std::sync::LazyLock::new(|| {
32        let mut map = HashMap::new();
33        register_array(&mut map);
34        register_matrix(&mut map);
35        register_cases(&mut map);
36        register_align(&mut map);
37        register_gathered(&mut map);
38        register_equation(&mut map);
39        register_smallmatrix(&mut map);
40        register_alignat(&mut map);
41        register_subarray(&mut map);
42        register_cd(&mut map);
43        map
44    });
45
46// ── ArrayConfig ──────────────────────────────────────────────────────────
47
/// Options controlling how [`parse_array`] reads an environment body, plus
/// layout hints it copies onto the resulting `ParseNode::Array`.
#[derive(Default)]
pub struct ArrayConfig {
    /// Copied verbatim to the array node.
    pub hskip_before_and_after: Option<bool>,
    /// Copied verbatim to the array node.
    pub add_jot: Option<bool>,
    /// Column specs copied to the array node (some handlers replace them
    /// after parsing).
    pub cols: Option<Vec<AlignSpec>>,
    /// Fixed row stretch. When `None`, `parse_array` reads the
    /// `\arraystretch` macro and falls back to 1.0.
    pub arraystretch: Option<f64>,
    /// Copied verbatim to the array node (e.g. `"align"`, `"gather"`).
    pub col_separation_type: Option<String>,
    /// When true, cells end at `\end` instead of `\\`, and `\cr` is not
    /// redefined.
    pub single_row: bool,
    /// When true, an empty trailing row is kept if it is the only row.
    pub empty_single_row: bool,
    /// Maximum cells per row; one more `&` fails with
    /// "Too many tab characters: &".
    pub max_num_cols: Option<usize>,
    /// Copied verbatim to the array node.
    pub leqno: Option<bool>,
    /// When true, rows without `\tag`/`\nonumber` receive an automatic
    /// equation number.
    pub auto_number: bool,
}
61
62
63// ── parseArray ───────────────────────────────────────────────────────────
64
65/// Pull a trailing `\\tag{…}` or `\\nonumber`/`\\notag` off the last cell of a row.
66/// Returns `Auto(true)` when the row is eligible for auto-numbering.
67/// The `auto_number` parameter controls the default when no marker is found.
68fn extract_trailing_tag_from_last_cell(row: &mut [ParseNode], auto_number: bool) -> ParseResult<ArrayTag> {
69    let default_tag = if auto_number { ArrayTag::Auto(true) } else { ArrayTag::Auto(false) };
70    let Some(last) = row.last_mut() else {
71        return Ok(default_tag);
72    };
73
74    let inner: &mut ParseNode = match last {
75        ParseNode::Styling { body, .. } => {
76            if body.len() != 1 {
77                return Ok(default_tag);
78            }
79            &mut body[0]
80        }
81        _ => last,
82    };
83
84    let obody = match inner {
85        ParseNode::OrdGroup { body, .. } => body,
86        _ => return Ok(default_tag),
87    };
88
89    // Look for \\tag
90    let tag_indices: Vec<usize> = obody
91        .iter()
92        .enumerate()
93        .filter(|(_, n)| matches!(n, ParseNode::Tag { .. }))
94        .map(|(i, _)| i)
95        .collect();
96
97    // Look for \\nonumber / \\notag
98    let nonumber_indices: Vec<usize> = obody
99        .iter()
100        .enumerate()
101        .filter(|(_, n)| matches!(n, ParseNode::NoNumber { .. }))
102        .map(|(i, _)| i)
103        .collect();
104
105    // Can't have both \\tag and \\nonumber in the same row
106    if !tag_indices.is_empty() && !nonumber_indices.is_empty() {
107        return Err(ParseError::msg(
108            "Cannot use both \\tag and \\nonumber in the same row",
109        ));
110    }
111
112    // Handle \\tag
113    if !tag_indices.is_empty() {
114        if tag_indices.len() > 1 {
115            return Err(ParseError::msg("Multiple \\tag in a row"));
116        }
117        let idx = tag_indices[0];
118        if idx != obody.len() - 1 {
119            return Err(ParseError::msg(
120                "\\tag must appear at the end of the row after the equation body",
121            ));
122        }
123        match obody.pop() {
124            Some(ParseNode::Tag { tag, .. }) => {
125                if tag.is_empty() {
126                    Ok(ArrayTag::Auto(false))
127                } else {
128                    Ok(ArrayTag::Explicit(tag))
129                }
130            }
131            _ => Ok(default_tag),
132        }
133    } else if !nonumber_indices.is_empty() {
134        // Handle \\nonumber / \\notag
135        if nonumber_indices.len() > 1 {
136            return Err(ParseError::msg("Multiple \\nonumber in a row"));
137        }
138        let idx = nonumber_indices[0];
139        if idx != obody.len() - 1 {
140            return Err(ParseError::msg(
141                "\\nonumber must appear at the end of the row",
142            ));
143        }
144        obody.pop(); // discard the NoNumber node
145        Ok(ArrayTag::Auto(false))
146    } else {
147        // Neither \\tag nor \\nonumber
148        Ok(default_tag)
149    }
150}
151
152fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
153    let mut hline_info = Vec::new();
154    parser.consume_spaces()?;
155
156    let mut nxt = parser.fetch()?.text.clone();
157    if nxt == "\\relax" {
158        parser.consume();
159        parser.consume_spaces()?;
160        nxt = parser.fetch()?.text.clone();
161    }
162    while nxt == "\\hline" || nxt == "\\hdashline" {
163        parser.consume();
164        hline_info.push(nxt == "\\hdashline");
165        parser.consume_spaces()?;
166        nxt = parser.fetch()?.text.clone();
167    }
168    Ok(hline_info)
169}
170
171fn d_cell_style(env_name: &str) -> Option<StyleStr> {
172    if env_name.starts_with('d') {
173        Some(StyleStr::Display)
174    } else {
175        Some(StyleStr::Text)
176    }
177}
178
179pub fn parse_array(
180    parser: &mut Parser,
181    config: ArrayConfig,
182    style: Option<StyleStr>,
183) -> ParseResult<ParseNode> {
184    parser.gullet.begin_group();
185
186    if !config.single_row {
187        parser
188            .gullet
189            .set_text_macro("\\cr", "\\\\\\relax");
190    }
191
192    let arraystretch = config.arraystretch.unwrap_or_else(|| {
193        // Check if \arraystretch is defined as a macro (e.g., via \def\arraystretch{1.5})
194        if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
195            let s = match def {
196                MacroDefinition::Text(s) => s.clone(),
197                MacroDefinition::Tokens { tokens, .. } => {
198                    // Tokens are stored in reverse order (stack convention for expansion)
199                    tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
200                }
201                MacroDefinition::Function(_) => String::new(),
202            };
203            s.parse::<f64>().unwrap_or(1.0)
204        } else {
205            1.0
206        }
207    });
208
209    parser.gullet.begin_group();
210
211    let mut row: Vec<ParseNode> = Vec::new();
212    let mut body: Vec<Vec<ParseNode>> = Vec::new();
213    let mut row_tags: Vec<ArrayTag> = Vec::new();
214    let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
215    let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
216
217    hlines_before_row.push(get_hlines(parser)?);
218
219    loop {
220        let break_token = if config.single_row { "\\end" } else { "\\\\" };
221        let cell_body = parser.parse_expression(false, Some(break_token))?;
222        parser.gullet.end_group();
223        parser.gullet.begin_group();
224
225        let mut cell = ParseNode::OrdGroup {
226            mode: parser.mode,
227            body: cell_body,
228            semisimple: None,
229            loc: None,
230        };
231
232        if let Some(s) = style {
233            cell = ParseNode::Styling {
234                mode: parser.mode,
235                style: s,
236                body: vec![cell],
237                loc: None,
238            };
239        }
240
241        row.push(cell.clone());
242        let next = parser.fetch()?.text.clone();
243
244        if next == "&" {
245            if let Some(max) = config.max_num_cols {
246                if row.len() >= max {
247                    return Err(ParseError::msg("Too many tab characters: &"));
248                }
249            }
250            parser.consume();
251        } else if next == "\\end" {
252            // Check for trailing empty row and remove it
253            let is_empty_trailing = if let Some(s) = style {
254                if s == StyleStr::Text || s == StyleStr::Display {
255                    if let ParseNode::Styling { body: ref sb, .. } = cell {
256                        if let Some(ParseNode::OrdGroup {
257                            body: ref ob, ..
258                        }) = sb.first()
259                        {
260                            ob.is_empty()
261                        } else {
262                            false
263                        }
264                    } else {
265                        false
266                    }
267                } else {
268                    false
269                }
270            } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
271                ob.is_empty()
272            } else {
273                false
274            };
275
276            let row_tag = extract_trailing_tag_from_last_cell(&mut row, config.auto_number)?;
277            row_tags.push(row_tag);
278            body.push(row);
279
280            if is_empty_trailing
281                && (body.len() > 1 || !config.empty_single_row)
282            {
283                body.pop();
284                row_tags.pop();
285            }
286
287            if hlines_before_row.len() < body.len() + 1 {
288                hlines_before_row.push(vec![]);
289            }
290            break;
291        } else if next == "\\\\" {
292            parser.consume();
293            let size = if parser.gullet.future().text != " " {
294                parser.parse_size_group(true)?
295            } else {
296                None
297            };
298            let gap = size.and_then(|s| {
299                if let ParseNode::Size { value, .. } = s {
300                    Some(value)
301                } else {
302                    None
303                }
304            });
305            row_gaps.push(gap);
306
307            let row_tag = extract_trailing_tag_from_last_cell(&mut row, config.auto_number)?;
308            row_tags.push(row_tag);
309            body.push(row);
310            hlines_before_row.push(get_hlines(parser)?);
311            row = Vec::new();
312        } else {
313            return Err(ParseError::msg(format!(
314                "Expected & or \\\\ or \\cr or \\end, got '{}'",
315                next
316            )));
317        }
318    }
319
320    parser.gullet.end_group();
321    parser.gullet.end_group();
322
323    // Post-process row tags for auto-numbering
324    let tags = if config.auto_number {
325        let mut processed: Vec<ArrayTag> = Vec::with_capacity(row_tags.len());
326        let mut any_visible = false;
327        for raw_tag in &row_tags {
328            match raw_tag {
329                ArrayTag::Explicit(nodes) if !nodes.is_empty() => {
330                    // Explicit \\tag{...}: step counter, keep tag content as-is
331                    parser.equation_counter += 1;
332                    processed.push(ArrayTag::Explicit(nodes.clone()));
333                    any_visible = true;
334                }
335                ArrayTag::Explicit(_) => {
336                    // Empty \\tag{}: treat as suppressed
337                    processed.push(ArrayTag::Auto(false));
338                }
339                ArrayTag::Auto(true) => {
340                    // Auto-number this row: step counter, generate "(N)"
341                    parser.equation_counter += 1;
342                    let num_str = parser.equation_counter.to_string();
343                    let tag_nodes = vec![
344                        ParseNode::MathOrd {
345                            mode: Mode::Math,
346                            text: "(".to_string(),
347                            loc: None,
348                        },
349                        ParseNode::MathOrd {
350                            mode: Mode::Math,
351                            text: num_str,
352                            loc: None,
353                        },
354                        ParseNode::MathOrd {
355                            mode: Mode::Math,
356                            text: ")".to_string(),
357                            loc: None,
358                        },
359                    ];
360                    processed.push(ArrayTag::Explicit(tag_nodes));
361                    any_visible = true;
362                }
363                ArrayTag::Auto(false) => {
364                    // Suppressed by \\nonumber or empty \\tag{}: no counter step, no tag
365                    processed.push(ArrayTag::Auto(false));
366                }
367            }
368        }
369        if any_visible { Some(processed) } else { None }
370    } else {
371        // Not an auto-numbering environment: keep original behavior
372        if row_tags.iter().any(|t| {
373            matches!(t, ArrayTag::Explicit(nodes) if !nodes.is_empty())
374        }) {
375            Some(row_tags)
376        } else {
377            None
378        }
379    };
380
381    Ok(ParseNode::Array {
382        mode: parser.mode,
383        body,
384        row_gaps,
385        hlines_before_row,
386        cols: config.cols,
387        col_separation_type: config.col_separation_type,
388        hskip_before_and_after: config.hskip_before_and_after,
389        add_jot: config.add_jot,
390        arraystretch,
391        tags,
392        leqno: config.leqno,
393        is_cd: None,
394        loc: None,
395    })
396}
397
398// ── array / darray ───────────────────────────────────────────────────────
399
400fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
401    fn handle_array(
402        ctx: &mut EnvContext,
403        args: Vec<ParseNode>,
404        _opt_args: Vec<Option<ParseNode>>,
405    ) -> ParseResult<ParseNode> {
406        let colalign = match &args[0] {
407            ParseNode::OrdGroup { body, .. } => body.clone(),
408            other if other.is_symbol_node() => vec![other.clone()],
409            _ => return Err(ParseError::msg("Invalid column alignment for array")),
410        };
411
412        let mut cols = Vec::new();
413        for nde in &colalign {
414            let ca = nde
415                .symbol_text()
416                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
417            match ca {
418                "l" | "c" | "r" => cols.push(AlignSpec {
419                    align_type: AlignType::Align,
420                    align: Some(ca.to_string()),
421                    pregap: None,
422                    postgap: None,
423                }),
424                "|" => cols.push(AlignSpec {
425                    align_type: AlignType::Separator,
426                    align: Some("|".to_string()),
427                    pregap: None,
428                    postgap: None,
429                }),
430                ":" => cols.push(AlignSpec {
431                    align_type: AlignType::Separator,
432                    align: Some(":".to_string()),
433                    pregap: None,
434                    postgap: None,
435                }),
436                _ => {
437                    return Err(ParseError::msg(format!(
438                        "Unknown column alignment: {}",
439                        ca
440                    )))
441                }
442            }
443        }
444
445        let max_num_cols = cols.len();
446        let config = ArrayConfig {
447            cols: Some(cols),
448            hskip_before_and_after: Some(true),
449            max_num_cols: Some(max_num_cols),
450            ..Default::default()
451        };
452        parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
453    }
454
455    for name in &["array", "darray"] {
456        map.insert(
457            name,
458            EnvSpec {
459                num_args: 1,
460                num_optional_args: 0,
461                handler: handle_array,
462            },
463        );
464    }
465}
466
467// ── matrix variants ──────────────────────────────────────────────────────
468
469fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
470    fn handle_matrix(
471        ctx: &mut EnvContext,
472        _args: Vec<ParseNode>,
473        _opt_args: Vec<Option<ParseNode>>,
474    ) -> ParseResult<ParseNode> {
475        let base_name = ctx.env_name.replace('*', "");
476        let delimiters: Option<(&str, &str)> = match base_name.as_str() {
477            "matrix" => None,
478            "pmatrix" => Some(("(", ")")),
479            "bmatrix" => Some(("[", "]")),
480            "Bmatrix" => Some(("\\{", "\\}")),
481            "vmatrix" => Some(("|", "|")),
482            "Vmatrix" => Some(("\\Vert", "\\Vert")),
483            _ => None,
484        };
485
486        let mut col_align = "c".to_string();
487
488        // mathtools starred matrix: parse optional [l|c|r] alignment
489        if ctx.env_name.ends_with('*') {
490            ctx.parser.gullet.consume_spaces();
491            if ctx.parser.gullet.future().text == "[" {
492                ctx.parser.gullet.pop_token();
493                ctx.parser.gullet.consume_spaces();
494                let align_tok = ctx.parser.gullet.pop_token();
495                if !"lcr".contains(align_tok.text.as_str()) {
496                    return Err(ParseError::new(
497                        "Expected l or c or r".to_string(),
498                        Some(&align_tok),
499                    ));
500                }
501                col_align = align_tok.text.clone();
502                ctx.parser.gullet.consume_spaces();
503                let close = ctx.parser.gullet.pop_token();
504                if close.text != "]" {
505                    return Err(ParseError::new(
506                        "Expected ]".to_string(),
507                        Some(&close),
508                    ));
509                }
510            }
511        }
512
513        let config = ArrayConfig {
514            hskip_before_and_after: Some(false),
515            cols: Some(vec![AlignSpec {
516                align_type: AlignType::Align,
517                align: Some(col_align.clone()),
518                pregap: None,
519                postgap: None,
520            }]),
521            ..Default::default()
522        };
523
524        let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
525
526        // Fix cols to match actual number of columns
527        if let ParseNode::Array {
528            ref body,
529            ref mut cols,
530            ..
531        } = res
532        {
533            let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
534            *cols = Some(
535                (0..num_cols)
536                    .map(|_| AlignSpec {
537                        align_type: AlignType::Align,
538                        align: Some(col_align.to_string()),
539                        pregap: None,
540                        postgap: None,
541                    })
542                    .collect(),
543            );
544        }
545
546        match delimiters {
547            Some((left, right)) => Ok(ParseNode::LeftRight {
548                mode: ctx.mode,
549                body: vec![res],
550                left: left.to_string(),
551                right: right.to_string(),
552                right_color: None,
553                loc: None,
554            }),
555            None => Ok(res),
556        }
557    }
558
559    for name in &[
560        "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
561        "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
562    ] {
563        map.insert(
564            name,
565            EnvSpec {
566                num_args: 0,
567                num_optional_args: 0,
568                handler: handle_matrix,
569            },
570        );
571    }
572}
573
574// ── cases / dcases / rcases / drcases ────────────────────────────────────
575
576fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
577    fn handle_cases(
578        ctx: &mut EnvContext,
579        _args: Vec<ParseNode>,
580        _opt_args: Vec<Option<ParseNode>>,
581    ) -> ParseResult<ParseNode> {
582        let config = ArrayConfig {
583            arraystretch: Some(1.2),
584            cols: Some(vec![
585                AlignSpec {
586                    align_type: AlignType::Align,
587                    align: Some("l".to_string()),
588                    pregap: Some(0.0),
589                    postgap: Some(1.0),
590                },
591                AlignSpec {
592                    align_type: AlignType::Align,
593                    align: Some("l".to_string()),
594                    pregap: Some(0.0),
595                    postgap: Some(0.0),
596                },
597            ]),
598            ..Default::default()
599        };
600
601        let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
602
603        let (left, right) = if ctx.env_name.contains('r') {
604            (".", "\\}")
605        } else {
606            ("\\{", ".")
607        };
608
609        Ok(ParseNode::LeftRight {
610            mode: ctx.mode,
611            body: vec![res],
612            left: left.to_string(),
613            right: right.to_string(),
614            right_color: None,
615            loc: None,
616        })
617    }
618
619    for name in &["cases", "dcases", "rcases", "drcases"] {
620        map.insert(
621            name,
622            EnvSpec {
623                num_args: 0,
624                num_optional_args: 0,
625                handler: handle_cases,
626            },
627        );
628    }
629}
630
631// ── align / align* / aligned / split / alignat / alignat* / alignedat ────
632
633fn handle_aligned(
634    ctx: &mut EnvContext,
635    args: Vec<ParseNode>,
636    _opt_args: Vec<Option<ParseNode>>,
637) -> ParseResult<ParseNode> {
638        let is_split = ctx.env_name == "split";
639        let is_alignat = ctx.env_name.contains("at");
640        let sep_type = if is_alignat { "alignat" } else { "align" };
641        let auto_number = !ctx.env_name.ends_with('*')
642            && !is_split
643            && ctx.env_name != "aligned"
644            && ctx.env_name != "alignedat";
645
646        let config = ArrayConfig {
647            add_jot: Some(true),
648            empty_single_row: true,
649            col_separation_type: Some(sep_type.to_string()),
650            max_num_cols: if is_split { Some(2) } else { None },
651            auto_number,
652            ..Default::default()
653        };
654
655        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
656
657        // Extract explicit column count from first arg (alignat only)
658        let mut num_maths = 0usize;
659        let mut explicit_cols = 0usize;
660        if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
661            let mut arg_str = String::new();
662            for node in body {
663                if let Some(t) = node.symbol_text() {
664                    arg_str.push_str(t);
665                }
666            }
667            if let Ok(n) = arg_str.parse::<usize>() {
668                num_maths = n;
669                explicit_cols = n * 2;
670            }
671        }
672        let is_aligned = explicit_cols == 0;
673
674        // Determine actual number of columns
675        let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
676            body.iter().map(|r| r.len()).max().unwrap_or(0)
677        } else {
678            0
679        };
680
681        if let ParseNode::Array {
682            body: ref mut array_body,
683            ..
684        } = res
685        {
686            for row in array_body.iter_mut() {
687                // Prepend empty group at every even-indexed cell (2nd, 4th, ...)
688                let mut i = 1;
689                while i < row.len() {
690                    if let ParseNode::Styling {
691                        body: ref mut styling_body,
692                        ..
693                    } = row[i]
694                    {
695                        if let Some(ParseNode::OrdGroup {
696                            body: ref mut og_body,
697                            ..
698                        }) = styling_body.first_mut()
699                        {
700                            og_body.insert(
701                                0,
702                                ParseNode::OrdGroup {
703                                    mode: ctx.mode,
704                                    body: vec![],
705                                    semisimple: None,
706                                    loc: None,
707                                },
708                            );
709                        }
710                    }
711                    i += 2;
712                }
713
714                if !is_aligned {
715                    let cur_maths = row.len() / 2;
716                    if num_maths < cur_maths {
717                        return Err(ParseError::msg(format!(
718                            "Too many math in a row: expected {}, but got {}",
719                            num_maths, cur_maths
720                        )));
721                    }
722                } else if num_cols < row.len() {
723                    num_cols = row.len();
724                }
725            }
726        }
727
728        if !is_aligned {
729            num_cols = explicit_cols;
730        }
731
732        let mut cols = Vec::new();
733        for i in 0..num_cols {
734            let (align, pregap) = if i % 2 == 1 {
735                ("l", 0.0)
736            } else if i > 0 && is_aligned {
737                ("r", 1.0)
738            } else {
739                ("r", 0.0)
740            };
741            cols.push(AlignSpec {
742                align_type: AlignType::Align,
743                align: Some(align.to_string()),
744                pregap: Some(pregap),
745                postgap: Some(0.0),
746            });
747        }
748
749        if let ParseNode::Array {
750            cols: ref mut array_cols,
751            col_separation_type: ref mut array_sep_type,
752            ..
753        } = res
754        {
755            *array_cols = Some(cols);
756            *array_sep_type = Some(
757                if is_aligned { "align" } else { "alignat" }.to_string(),
758            );
759        }
760
761    Ok(res)
762}
763
764fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
765    for name in &["align", "align*", "aligned", "split"] {
766        map.insert(
767            name,
768            EnvSpec {
769                num_args: 0,
770                num_optional_args: 0,
771                handler: handle_aligned,
772            },
773        );
774    }
775}
776
777// ── gathered / gather / gather* ──────────────────────────────────────────
778
779fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
780    fn handle_gathered(
781        ctx: &mut EnvContext,
782        _args: Vec<ParseNode>,
783        _opt_args: Vec<Option<ParseNode>>,
784    ) -> ParseResult<ParseNode> {
785        let auto_number = !ctx.env_name.ends_with('*') && ctx.env_name != "gathered";
786        let config = ArrayConfig {
787            cols: Some(vec![AlignSpec {
788                align_type: AlignType::Align,
789                align: Some("c".to_string()),
790                pregap: None,
791                postgap: None,
792            }]),
793            add_jot: Some(true),
794            col_separation_type: Some("gather".to_string()),
795            empty_single_row: true,
796            auto_number,
797            ..Default::default()
798        };
799        parse_array(ctx.parser, config, Some(StyleStr::Display))
800    }
801
802    for name in &["gathered", "gather", "gather*"] {
803        map.insert(
804            name,
805            EnvSpec {
806                num_args: 0,
807                num_optional_args: 0,
808                handler: handle_gathered,
809            },
810        );
811    }
812}
813
814// ── equation / equation* ─────────────────────────────────────────────────
815
816fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
817    fn handle_equation(
818        ctx: &mut EnvContext,
819        _args: Vec<ParseNode>,
820        _opt_args: Vec<Option<ParseNode>>,
821    ) -> ParseResult<ParseNode> {
822        let auto_number = !ctx.env_name.ends_with('*');
823        let config = ArrayConfig {
824            empty_single_row: true,
825            single_row: true,
826            max_num_cols: Some(1),
827            auto_number,
828            ..Default::default()
829        };
830        parse_array(ctx.parser, config, Some(StyleStr::Display))
831    }
832
833    for name in &["equation", "equation*"] {
834        map.insert(
835            name,
836            EnvSpec {
837                num_args: 0,
838                num_optional_args: 0,
839                handler: handle_equation,
840            },
841        );
842    }
843}
844
845// ── smallmatrix ──────────────────────────────────────────────────────────
846
847fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
848    fn handle_smallmatrix(
849        ctx: &mut EnvContext,
850        _args: Vec<ParseNode>,
851        _opt_args: Vec<Option<ParseNode>>,
852    ) -> ParseResult<ParseNode> {
853        let config = ArrayConfig {
854            arraystretch: Some(0.5),
855            ..Default::default()
856        };
857        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
858        if let ParseNode::Array {
859            ref mut col_separation_type,
860            ..
861        } = res
862        {
863            *col_separation_type = Some("small".to_string());
864        }
865        Ok(res)
866    }
867
868    map.insert(
869        "smallmatrix",
870        EnvSpec {
871            num_args: 0,
872            num_optional_args: 0,
873            handler: handle_smallmatrix,
874        },
875    );
876}
877
878// ── alignat / alignat* / alignedat ──────────────────────────────────────
879
880fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
881    for name in &["alignat", "alignat*", "alignedat"] {
882        map.insert(
883            name,
884            EnvSpec {
885                num_args: 1,
886                num_optional_args: 0,
887                handler: handle_aligned,
888            },
889        );
890    }
891}
892
893// ── CD (amscd commutative diagrams) ──────────────────────────────────────
894
895fn register_cd(map: &mut HashMap<&'static str, EnvSpec>) {
896    fn handle_cd(
897        ctx: &mut EnvContext,
898        _args: Vec<ParseNode>,
899        _opt_args: Vec<Option<ParseNode>>,
900    ) -> ParseResult<ParseNode> {
901        // Collect all raw tokens until \end
902        let mut raw: Vec<Token> = Vec::new();
903        loop {
904            let tok = ctx.parser.gullet.future().clone();
905            if tok.text == "\\end" || tok.text == "EOF" {
906                break;
907            }
908            ctx.parser.gullet.pop_token();
909            raw.push(tok);
910        }
911
912        // Split into rows at \\ or \cr
913        let rows = cd_split_rows(raw);
914
915        let mut body: Vec<Vec<ParseNode>> = Vec::new();
916        let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
917        let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
918        hlines_before_row.push(vec![]);
919
920        for row_toks in rows {
921            // Skip purely-whitespace rows
922            if row_toks.iter().all(|t| t.text == " ") {
923                continue;
924            }
925            let cells = cd_parse_row(ctx.parser, row_toks)?;
926            if !cells.is_empty() {
927                body.push(cells);
928                row_gaps.push(None);
929                hlines_before_row.push(vec![]);
930            }
931        }
932
933        if body.is_empty() {
934            body.push(vec![]);
935            hlines_before_row.push(vec![]);
936        }
937
938        Ok(ParseNode::Array {
939            mode: ctx.mode,
940            body,
941            row_gaps,
942            hlines_before_row,
943            cols: None,
944            col_separation_type: Some("CD".to_string()),
945            hskip_before_and_after: Some(false),
946            add_jot: None,
947            arraystretch: 1.0,
948            tags: None,
949            leqno: None,
950            is_cd: Some(true),
951            loc: None,
952        })
953    }
954
955    map.insert(
956        "CD",
957        EnvSpec {
958            num_args: 0,
959            num_optional_args: 0,
960            handler: handle_cd,
961        },
962    );
963}
964
965/// Split a flat token list into rows at `\\` or `\cr` boundaries.
966fn cd_split_rows(tokens: Vec<Token>) -> Vec<Vec<Token>> {
967    let mut rows: Vec<Vec<Token>> = Vec::new();
968    let mut current: Vec<Token> = Vec::new();
969    for tok in tokens {
970        if tok.text == "\\\\" || tok.text == "\\cr" {
971            rows.push(current);
972            current = Vec::new();
973        } else {
974            current.push(tok);
975        }
976    }
977    if !current.is_empty() {
978        rows.push(current);
979    }
980    rows
981}
982
983/// Collect tokens from `tokens[start..]` up to (but not including) the first
984/// token whose text equals `delimiter`.  Returns (collected_tokens, tokens_consumed).
985/// `tokens_consumed` includes the delimiter itself if found.
986fn cd_collect_until(tokens: &[Token], start: usize, delimiter: &str) -> (Vec<Token>, usize) {
987    let mut result = Vec::new();
988    let mut i = start;
989    while i < tokens.len() {
990        if tokens[i].text == delimiter {
991            i += 1; // consume the delimiter
992            break;
993        }
994        result.push(tokens[i].clone());
995        i += 1;
996    }
997    (result, i - start)
998}
999
1000/// Collect tokens from `tokens[start..]` up to (but not including) the next `@`.
1001fn cd_collect_until_at(tokens: &[Token], start: usize) -> (Vec<Token>, usize) {
1002    let mut result = Vec::new();
1003    let mut i = start;
1004    while i < tokens.len() && tokens[i].text != "@" {
1005        result.push(tokens[i].clone());
1006        i += 1;
1007    }
1008    (result, i - start)
1009}
1010
1011/// Use the parser to parse a token slice as a math OrdGroup.
1012/// Tokens must be in forward order; this function reverses them internally for subparse().
1013fn cd_parse_tokens(parser: &mut Parser, tokens: Vec<Token>) -> ParseResult<ParseNode> {
1014    // Filter pure whitespace
1015    let has_content = tokens.iter().any(|t| t.text != " ");
1016    if !has_content {
1017        return Ok(ParseNode::OrdGroup {
1018            mode: parser.mode,
1019            body: vec![],
1020            semisimple: None,
1021            loc: None,
1022        });
1023    }
1024    // subparse() expects tokens in reverse order (stack convention)
1025    let mut rev = tokens;
1026    rev.reverse();
1027    let body = parser.subparse(rev)?;
1028    Ok(ParseNode::OrdGroup {
1029        mode: parser.mode,
1030        body,
1031        semisimple: None,
1032        loc: None,
1033    })
1034}
1035
1036/// Parse one row of a CD environment from its raw token list.
1037/// Returns the list of ParseNode cells for the grid row.
1038fn cd_parse_row(parser: &mut Parser, row_tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
1039    let toks = &row_tokens;
1040    let n = toks.len();
1041    let mut cells: Vec<ParseNode> = Vec::new();
1042    let mut i = 0usize;
1043
1044    while i < n {
1045        // Skip spaces at start of each cell
1046        while i < n && toks[i].text == " " {
1047            i += 1;
1048        }
1049        if i >= n {
1050            break;
1051        }
1052
1053        if toks[i].text == "@" {
1054            i += 1; // consume `@`
1055            if i >= n {
1056                return Err(ParseError::msg("Unexpected end of CD row after @"));
1057            }
1058            let dir = toks[i].text.clone();
1059            i += 1; // consume direction char
1060
1061            let mode = parser.mode;
1062            let arrow = match dir.as_str() {
1063                ">" | "<" => {
1064                    let (above_toks, c1) = cd_collect_until(toks, i, &dir);
1065                    i += c1;
1066                    let (below_toks, c2) = cd_collect_until(toks, i, &dir);
1067                    i += c2;
1068                    let label_above = cd_parse_tokens(parser, above_toks)?;
1069                    let label_below = cd_parse_tokens(parser, below_toks)?;
1070                    ParseNode::CdArrow {
1071                        mode,
1072                        direction: if dir == ">" { "right" } else { "left" }.to_string(),
1073                        label_above: Some(Box::new(label_above)),
1074                        label_below: Some(Box::new(label_below)),
1075                        loc: None,
1076                    }
1077                }
1078                "V" | "A" => {
1079                    let (left_toks, c1) = cd_collect_until(toks, i, &dir);
1080                    i += c1;
1081                    let (right_toks, c2) = cd_collect_until(toks, i, &dir);
1082                    i += c2;
1083                    let label_above = cd_parse_tokens(parser, left_toks)?;
1084                    let label_below = cd_parse_tokens(parser, right_toks)?;
1085                    ParseNode::CdArrow {
1086                        mode,
1087                        direction: if dir == "V" { "down" } else { "up" }.to_string(),
1088                        label_above: Some(Box::new(label_above)),
1089                        label_below: Some(Box::new(label_below)),
1090                        loc: None,
1091                    }
1092                }
1093                "=" => ParseNode::CdArrow {
1094                    mode,
1095                    direction: "horiz_eq".to_string(),
1096                    label_above: None,
1097                    label_below: None,
1098                    loc: None,
1099                },
1100                "|" => ParseNode::CdArrow {
1101                    mode,
1102                    direction: "vert_eq".to_string(),
1103                    label_above: None,
1104                    label_below: None,
1105                    loc: None,
1106                },
1107                "." => ParseNode::CdArrow {
1108                    mode,
1109                    direction: "none".to_string(),
1110                    label_above: None,
1111                    label_below: None,
1112                    loc: None,
1113                },
1114                _ => return Err(ParseError::msg(format!("Unknown CD directive: @{}", dir))),
1115            };
1116            cells.push(arrow);
1117        } else {
1118            // Object cell: collect until next `@`
1119            let (obj_toks, consumed) = cd_collect_until_at(toks, i);
1120            i += consumed;
1121            let obj = cd_parse_tokens(parser, obj_toks)?;
1122            cells.push(obj);
1123        }
1124    }
1125
1126    // Post-process: structure cells into the (2n-1) grid pattern.
1127    Ok(cd_structure_row(cells, parser.mode))
1128}
1129
1130/// Given the raw parsed cells of one CD row, produce the correctly-structured grid row.
1131///
1132/// Object rows already alternate: obj, h-arrow, obj, h-arrow, …, obj.
1133/// Arrow rows contain only CdArrow nodes (plus whitespace OrdGroups which we strip),
1134/// and need empty OrdGroup fillers inserted between consecutive arrows.
1135fn cd_structure_row(cells: Vec<ParseNode>, mode: Mode) -> Vec<ParseNode> {
1136    // Detect arrow row: all cells are either CdArrow or empty OrdGroup
1137    let is_arrow_row = cells.iter().all(|c| match c {
1138        ParseNode::CdArrow { .. } => true,
1139        ParseNode::OrdGroup { body, .. } => body.is_empty(),
1140        _ => false,
1141    }) && cells.iter().any(|c| matches!(c, ParseNode::CdArrow { .. }));
1142
1143    if is_arrow_row {
1144        let arrows: Vec<ParseNode> = cells
1145            .into_iter()
1146            .filter(|c| matches!(c, ParseNode::CdArrow { .. }))
1147            .collect();
1148
1149        if arrows.is_empty() {
1150            return vec![];
1151        }
1152
1153        let empty = || ParseNode::OrdGroup {
1154            mode,
1155            body: vec![],
1156            semisimple: None,
1157            loc: None,
1158        };
1159
1160        let mut result = Vec::with_capacity(arrows.len() * 2 - 1);
1161        for (idx, arrow) in arrows.into_iter().enumerate() {
1162            if idx > 0 {
1163                result.push(empty());
1164            }
1165            result.push(arrow);
1166        }
1167        result
1168    } else {
1169        // Object row: already in correct format
1170        cells
1171    }
1172}
1173
1174// ── subarray ────────────────────────────────────────────────────────────
1175
1176fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
1177    fn handle_subarray(
1178        ctx: &mut EnvContext,
1179        args: Vec<ParseNode>,
1180        _opt_args: Vec<Option<ParseNode>>,
1181    ) -> ParseResult<ParseNode> {
1182        let colalign = match &args[0] {
1183            ParseNode::OrdGroup { body, .. } => body.clone(),
1184            other if other.is_symbol_node() => vec![other.clone()],
1185            _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
1186        };
1187
1188        let mut cols = Vec::new();
1189        for nde in &colalign {
1190            let ca = nde
1191                .symbol_text()
1192                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
1193            match ca {
1194                "l" | "c" => cols.push(AlignSpec {
1195                    align_type: AlignType::Align,
1196                    align: Some(ca.to_string()),
1197                    pregap: None,
1198                    postgap: None,
1199                }),
1200                _ => {
1201                    return Err(ParseError::msg(format!(
1202                        "Unknown column alignment: {}",
1203                        ca
1204                    )))
1205                }
1206            }
1207        }
1208
1209        if cols.len() > 1 {
1210            return Err(ParseError::msg("{subarray} can contain only one column"));
1211        }
1212
1213        let config = ArrayConfig {
1214            cols: Some(cols),
1215            hskip_before_and_after: Some(false),
1216            arraystretch: Some(0.5),
1217            ..Default::default()
1218        };
1219
1220        let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
1221
1222        if let ParseNode::Array { ref body, .. } = res {
1223            if !body.is_empty() && body[0].len() > 1 {
1224                return Err(ParseError::msg("{subarray} can contain only one column"));
1225            }
1226        }
1227
1228        Ok(res)
1229    }
1230
1231    map.insert(
1232        "subarray",
1233        EnvSpec {
1234            num_args: 1,
1235            num_optional_args: 0,
1236            handler: handle_subarray,
1237        },
1238    );
1239}