Skip to main content

ratex_parser/
environments.rs

1use std::collections::HashMap;
2
3use ratex_lexer::token::Token;
4
5use crate::error::{ParseError, ParseResult};
6use crate::macro_expander::MacroDefinition;
7use crate::parse_node::{AlignSpec, AlignType, ArrayTag, Measurement, Mode, ParseNode, StyleStr};
8use crate::parser::Parser;
9
10// ── Environment registry ─────────────────────────────────────────────────
11
12pub struct EnvContext<'a, 'b> {
13    pub mode: Mode,
14    pub env_name: String,
15    pub parser: &'a mut Parser<'b>,
16}
17
18pub type EnvHandler = fn(
19    ctx: &mut EnvContext,
20    args: Vec<ParseNode>,
21    opt_args: Vec<Option<ParseNode>>,
22) -> ParseResult<ParseNode>;
23
24pub struct EnvSpec {
25    pub num_args: usize,
26    pub num_optional_args: usize,
27    pub handler: EnvHandler,
28}
29
30pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
31    std::sync::LazyLock::new(|| {
32        let mut map = HashMap::new();
33        register_array(&mut map);
34        register_matrix(&mut map);
35        register_cases(&mut map);
36        register_align(&mut map);
37        register_gathered(&mut map);
38        register_equation(&mut map);
39        register_smallmatrix(&mut map);
40        register_alignat(&mut map);
41        register_subarray(&mut map);
42        register_cd(&mut map);
43        map
44    });
45
46// ── ArrayConfig ──────────────────────────────────────────────────────────
47
48#[derive(Default)]
49pub struct ArrayConfig {
50    pub hskip_before_and_after: Option<bool>,
51    pub add_jot: Option<bool>,
52    pub cols: Option<Vec<AlignSpec>>,
53    pub arraystretch: Option<f64>,
54    pub col_separation_type: Option<String>,
55    pub single_row: bool,
56    pub empty_single_row: bool,
57    pub max_num_cols: Option<usize>,
58    pub leqno: Option<bool>,
59}
60
61
62// ── parseArray ───────────────────────────────────────────────────────────
63
64/// Pull a trailing `\\tag{…}` off the last cell of a row (amsmath semantics).
65fn extract_trailing_tag_from_last_cell(row: &mut [ParseNode]) -> ParseResult<ArrayTag> {
66    let Some(last) = row.last_mut() else {
67        return Ok(ArrayTag::Auto(false));
68    };
69
70    let inner: &mut ParseNode = match last {
71        ParseNode::Styling { body, .. } => {
72            if body.len() != 1 {
73                return Ok(ArrayTag::Auto(false));
74            }
75            &mut body[0]
76        }
77        _ => last,
78    };
79
80    let obody = match inner {
81        ParseNode::OrdGroup { body, .. } => body,
82        _ => return Ok(ArrayTag::Auto(false)),
83    };
84
85    let tag_indices: Vec<usize> = obody
86        .iter()
87        .enumerate()
88        .filter(|(_, n)| matches!(n, ParseNode::Tag { .. }))
89        .map(|(i, _)| i)
90        .collect();
91
92    if tag_indices.is_empty() {
93        return Ok(ArrayTag::Auto(false));
94    }
95    if tag_indices.len() > 1 {
96        return Err(ParseError::msg("Multiple \\tag in a row"));
97    }
98    let idx = tag_indices[0];
99    if idx != obody.len() - 1 {
100        return Err(ParseError::msg(
101            "\\tag must appear at the end of the row after the equation body",
102        ));
103    }
104
105    match obody.pop() {
106        Some(ParseNode::Tag { tag, .. }) => {
107            if tag.is_empty() {
108                Ok(ArrayTag::Auto(false))
109            } else {
110                Ok(ArrayTag::Explicit(tag))
111            }
112        }
113        _ => Ok(ArrayTag::Auto(false)),
114    }
115}
116
117fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
118    let mut hline_info = Vec::new();
119    parser.consume_spaces()?;
120
121    let mut nxt = parser.fetch()?.text.clone();
122    if nxt == "\\relax" {
123        parser.consume();
124        parser.consume_spaces()?;
125        nxt = parser.fetch()?.text.clone();
126    }
127    while nxt == "\\hline" || nxt == "\\hdashline" {
128        parser.consume();
129        hline_info.push(nxt == "\\hdashline");
130        parser.consume_spaces()?;
131        nxt = parser.fetch()?.text.clone();
132    }
133    Ok(hline_info)
134}
135
136fn d_cell_style(env_name: &str) -> Option<StyleStr> {
137    if env_name.starts_with('d') {
138        Some(StyleStr::Display)
139    } else {
140        Some(StyleStr::Text)
141    }
142}
143
144pub fn parse_array(
145    parser: &mut Parser,
146    config: ArrayConfig,
147    style: Option<StyleStr>,
148) -> ParseResult<ParseNode> {
149    parser.gullet.begin_group();
150
151    if !config.single_row {
152        parser
153            .gullet
154            .set_text_macro("\\cr", "\\\\\\relax");
155    }
156
157    let arraystretch = config.arraystretch.unwrap_or_else(|| {
158        // Check if \arraystretch is defined as a macro (e.g., via \def\arraystretch{1.5})
159        if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
160            let s = match def {
161                MacroDefinition::Text(s) => s.clone(),
162                MacroDefinition::Tokens { tokens, .. } => {
163                    // Tokens are stored in reverse order (stack convention for expansion)
164                    tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
165                }
166                MacroDefinition::Function(_) => String::new(),
167            };
168            s.parse::<f64>().unwrap_or(1.0)
169        } else {
170            1.0
171        }
172    });
173
174    parser.gullet.begin_group();
175
176    let mut row: Vec<ParseNode> = Vec::new();
177    let mut body: Vec<Vec<ParseNode>> = Vec::new();
178    let mut row_tags: Vec<ArrayTag> = Vec::new();
179    let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
180    let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
181
182    hlines_before_row.push(get_hlines(parser)?);
183
184    loop {
185        let break_token = if config.single_row { "\\end" } else { "\\\\" };
186        let cell_body = parser.parse_expression(false, Some(break_token))?;
187        parser.gullet.end_group();
188        parser.gullet.begin_group();
189
190        let mut cell = ParseNode::OrdGroup {
191            mode: parser.mode,
192            body: cell_body,
193            semisimple: None,
194            loc: None,
195        };
196
197        if let Some(s) = style {
198            cell = ParseNode::Styling {
199                mode: parser.mode,
200                style: s,
201                body: vec![cell],
202                loc: None,
203            };
204        }
205
206        row.push(cell.clone());
207        let next = parser.fetch()?.text.clone();
208
209        if next == "&" {
210            if let Some(max) = config.max_num_cols {
211                if row.len() >= max {
212                    return Err(ParseError::msg("Too many tab characters: &"));
213                }
214            }
215            parser.consume();
216        } else if next == "\\end" {
217            // Check for trailing empty row and remove it
218            let is_empty_trailing = if let Some(s) = style {
219                if s == StyleStr::Text || s == StyleStr::Display {
220                    if let ParseNode::Styling { body: ref sb, .. } = cell {
221                        if let Some(ParseNode::OrdGroup {
222                            body: ref ob, ..
223                        }) = sb.first()
224                        {
225                            ob.is_empty()
226                        } else {
227                            false
228                        }
229                    } else {
230                        false
231                    }
232                } else {
233                    false
234                }
235            } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
236                ob.is_empty()
237            } else {
238                false
239            };
240
241            let row_tag = extract_trailing_tag_from_last_cell(&mut row)?;
242            row_tags.push(row_tag);
243            body.push(row);
244
245            if is_empty_trailing
246                && (body.len() > 1 || !config.empty_single_row)
247            {
248                body.pop();
249                row_tags.pop();
250            }
251
252            if hlines_before_row.len() < body.len() + 1 {
253                hlines_before_row.push(vec![]);
254            }
255            break;
256        } else if next == "\\\\" {
257            parser.consume();
258            let size = if parser.gullet.future().text != " " {
259                parser.parse_size_group(true)?
260            } else {
261                None
262            };
263            let gap = size.and_then(|s| {
264                if let ParseNode::Size { value, .. } = s {
265                    Some(value)
266                } else {
267                    None
268                }
269            });
270            row_gaps.push(gap);
271
272            let row_tag = extract_trailing_tag_from_last_cell(&mut row)?;
273            row_tags.push(row_tag);
274            body.push(row);
275            hlines_before_row.push(get_hlines(parser)?);
276            row = Vec::new();
277        } else {
278            return Err(ParseError::msg(format!(
279                "Expected & or \\\\ or \\cr or \\end, got '{}'",
280                next
281            )));
282        }
283    }
284
285    parser.gullet.end_group();
286    parser.gullet.end_group();
287
288    let tags = if row_tags.iter().any(|t| {
289        matches!(t, ArrayTag::Explicit(nodes) if !nodes.is_empty())
290    }) {
291        Some(row_tags)
292    } else {
293        None
294    };
295
296    Ok(ParseNode::Array {
297        mode: parser.mode,
298        body,
299        row_gaps,
300        hlines_before_row,
301        cols: config.cols,
302        col_separation_type: config.col_separation_type,
303        hskip_before_and_after: config.hskip_before_and_after,
304        add_jot: config.add_jot,
305        arraystretch,
306        tags,
307        leqno: config.leqno,
308        is_cd: None,
309        loc: None,
310    })
311}
312
313// ── array / darray ───────────────────────────────────────────────────────
314
315fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
316    fn handle_array(
317        ctx: &mut EnvContext,
318        args: Vec<ParseNode>,
319        _opt_args: Vec<Option<ParseNode>>,
320    ) -> ParseResult<ParseNode> {
321        let colalign = match &args[0] {
322            ParseNode::OrdGroup { body, .. } => body.clone(),
323            other if other.is_symbol_node() => vec![other.clone()],
324            _ => return Err(ParseError::msg("Invalid column alignment for array")),
325        };
326
327        let mut cols = Vec::new();
328        for nde in &colalign {
329            let ca = nde
330                .symbol_text()
331                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
332            match ca {
333                "l" | "c" | "r" => cols.push(AlignSpec {
334                    align_type: AlignType::Align,
335                    align: Some(ca.to_string()),
336                    pregap: None,
337                    postgap: None,
338                }),
339                "|" => cols.push(AlignSpec {
340                    align_type: AlignType::Separator,
341                    align: Some("|".to_string()),
342                    pregap: None,
343                    postgap: None,
344                }),
345                ":" => cols.push(AlignSpec {
346                    align_type: AlignType::Separator,
347                    align: Some(":".to_string()),
348                    pregap: None,
349                    postgap: None,
350                }),
351                _ => {
352                    return Err(ParseError::msg(format!(
353                        "Unknown column alignment: {}",
354                        ca
355                    )))
356                }
357            }
358        }
359
360        let max_num_cols = cols.len();
361        let config = ArrayConfig {
362            cols: Some(cols),
363            hskip_before_and_after: Some(true),
364            max_num_cols: Some(max_num_cols),
365            ..Default::default()
366        };
367        parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
368    }
369
370    for name in &["array", "darray"] {
371        map.insert(
372            name,
373            EnvSpec {
374                num_args: 1,
375                num_optional_args: 0,
376                handler: handle_array,
377            },
378        );
379    }
380}
381
382// ── matrix variants ──────────────────────────────────────────────────────
383
384fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
385    fn handle_matrix(
386        ctx: &mut EnvContext,
387        _args: Vec<ParseNode>,
388        _opt_args: Vec<Option<ParseNode>>,
389    ) -> ParseResult<ParseNode> {
390        let base_name = ctx.env_name.replace('*', "");
391        let delimiters: Option<(&str, &str)> = match base_name.as_str() {
392            "matrix" => None,
393            "pmatrix" => Some(("(", ")")),
394            "bmatrix" => Some(("[", "]")),
395            "Bmatrix" => Some(("\\{", "\\}")),
396            "vmatrix" => Some(("|", "|")),
397            "Vmatrix" => Some(("\\Vert", "\\Vert")),
398            _ => None,
399        };
400
401        let mut col_align = "c".to_string();
402
403        // mathtools starred matrix: parse optional [l|c|r] alignment
404        if ctx.env_name.ends_with('*') {
405            ctx.parser.gullet.consume_spaces();
406            if ctx.parser.gullet.future().text == "[" {
407                ctx.parser.gullet.pop_token();
408                ctx.parser.gullet.consume_spaces();
409                let align_tok = ctx.parser.gullet.pop_token();
410                if !"lcr".contains(align_tok.text.as_str()) {
411                    return Err(ParseError::new(
412                        "Expected l or c or r".to_string(),
413                        Some(&align_tok),
414                    ));
415                }
416                col_align = align_tok.text.clone();
417                ctx.parser.gullet.consume_spaces();
418                let close = ctx.parser.gullet.pop_token();
419                if close.text != "]" {
420                    return Err(ParseError::new(
421                        "Expected ]".to_string(),
422                        Some(&close),
423                    ));
424                }
425            }
426        }
427
428        let config = ArrayConfig {
429            hskip_before_and_after: Some(false),
430            cols: Some(vec![AlignSpec {
431                align_type: AlignType::Align,
432                align: Some(col_align.clone()),
433                pregap: None,
434                postgap: None,
435            }]),
436            ..Default::default()
437        };
438
439        let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
440
441        // Fix cols to match actual number of columns
442        if let ParseNode::Array {
443            ref body,
444            ref mut cols,
445            ..
446        } = res
447        {
448            let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
449            *cols = Some(
450                (0..num_cols)
451                    .map(|_| AlignSpec {
452                        align_type: AlignType::Align,
453                        align: Some(col_align.to_string()),
454                        pregap: None,
455                        postgap: None,
456                    })
457                    .collect(),
458            );
459        }
460
461        match delimiters {
462            Some((left, right)) => Ok(ParseNode::LeftRight {
463                mode: ctx.mode,
464                body: vec![res],
465                left: left.to_string(),
466                right: right.to_string(),
467                right_color: None,
468                loc: None,
469            }),
470            None => Ok(res),
471        }
472    }
473
474    for name in &[
475        "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
476        "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
477    ] {
478        map.insert(
479            name,
480            EnvSpec {
481                num_args: 0,
482                num_optional_args: 0,
483                handler: handle_matrix,
484            },
485        );
486    }
487}
488
489// ── cases / dcases / rcases / drcases ────────────────────────────────────
490
491fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
492    fn handle_cases(
493        ctx: &mut EnvContext,
494        _args: Vec<ParseNode>,
495        _opt_args: Vec<Option<ParseNode>>,
496    ) -> ParseResult<ParseNode> {
497        let config = ArrayConfig {
498            arraystretch: Some(1.2),
499            cols: Some(vec![
500                AlignSpec {
501                    align_type: AlignType::Align,
502                    align: Some("l".to_string()),
503                    pregap: Some(0.0),
504                    postgap: Some(1.0),
505                },
506                AlignSpec {
507                    align_type: AlignType::Align,
508                    align: Some("l".to_string()),
509                    pregap: Some(0.0),
510                    postgap: Some(0.0),
511                },
512            ]),
513            ..Default::default()
514        };
515
516        let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
517
518        let (left, right) = if ctx.env_name.contains('r') {
519            (".", "\\}")
520        } else {
521            ("\\{", ".")
522        };
523
524        Ok(ParseNode::LeftRight {
525            mode: ctx.mode,
526            body: vec![res],
527            left: left.to_string(),
528            right: right.to_string(),
529            right_color: None,
530            loc: None,
531        })
532    }
533
534    for name in &["cases", "dcases", "rcases", "drcases"] {
535        map.insert(
536            name,
537            EnvSpec {
538                num_args: 0,
539                num_optional_args: 0,
540                handler: handle_cases,
541            },
542        );
543    }
544}
545
546// ── align / align* / aligned / split / alignat / alignat* / alignedat ────
547
548fn handle_aligned(
549    ctx: &mut EnvContext,
550    args: Vec<ParseNode>,
551    _opt_args: Vec<Option<ParseNode>>,
552) -> ParseResult<ParseNode> {
553        let is_split = ctx.env_name == "split";
554        let is_alignat = ctx.env_name.contains("at");
555        let sep_type = if is_alignat { "alignat" } else { "align" };
556
557        let config = ArrayConfig {
558            add_jot: Some(true),
559            empty_single_row: true,
560            col_separation_type: Some(sep_type.to_string()),
561            max_num_cols: if is_split { Some(2) } else { None },
562            ..Default::default()
563        };
564
565        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
566
567        // Extract explicit column count from first arg (alignat only)
568        let mut num_maths = 0usize;
569        let mut explicit_cols = 0usize;
570        if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
571            let mut arg_str = String::new();
572            for node in body {
573                if let Some(t) = node.symbol_text() {
574                    arg_str.push_str(t);
575                }
576            }
577            if let Ok(n) = arg_str.parse::<usize>() {
578                num_maths = n;
579                explicit_cols = n * 2;
580            }
581        }
582        let is_aligned = explicit_cols == 0;
583
584        // Determine actual number of columns
585        let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
586            body.iter().map(|r| r.len()).max().unwrap_or(0)
587        } else {
588            0
589        };
590
591        if let ParseNode::Array {
592            body: ref mut array_body,
593            ..
594        } = res
595        {
596            for row in array_body.iter_mut() {
597                // Prepend empty group at every even-indexed cell (2nd, 4th, ...)
598                let mut i = 1;
599                while i < row.len() {
600                    if let ParseNode::Styling {
601                        body: ref mut styling_body,
602                        ..
603                    } = row[i]
604                    {
605                        if let Some(ParseNode::OrdGroup {
606                            body: ref mut og_body,
607                            ..
608                        }) = styling_body.first_mut()
609                        {
610                            og_body.insert(
611                                0,
612                                ParseNode::OrdGroup {
613                                    mode: ctx.mode,
614                                    body: vec![],
615                                    semisimple: None,
616                                    loc: None,
617                                },
618                            );
619                        }
620                    }
621                    i += 2;
622                }
623
624                if !is_aligned {
625                    let cur_maths = row.len() / 2;
626                    if num_maths < cur_maths {
627                        return Err(ParseError::msg(format!(
628                            "Too many math in a row: expected {}, but got {}",
629                            num_maths, cur_maths
630                        )));
631                    }
632                } else if num_cols < row.len() {
633                    num_cols = row.len();
634                }
635            }
636        }
637
638        if !is_aligned {
639            num_cols = explicit_cols;
640        }
641
642        let mut cols = Vec::new();
643        for i in 0..num_cols {
644            let (align, pregap) = if i % 2 == 1 {
645                ("l", 0.0)
646            } else if i > 0 && is_aligned {
647                ("r", 1.0)
648            } else {
649                ("r", 0.0)
650            };
651            cols.push(AlignSpec {
652                align_type: AlignType::Align,
653                align: Some(align.to_string()),
654                pregap: Some(pregap),
655                postgap: Some(0.0),
656            });
657        }
658
659        if let ParseNode::Array {
660            cols: ref mut array_cols,
661            col_separation_type: ref mut array_sep_type,
662            ..
663        } = res
664        {
665            *array_cols = Some(cols);
666            *array_sep_type = Some(
667                if is_aligned { "align" } else { "alignat" }.to_string(),
668            );
669        }
670
671    Ok(res)
672}
673
674fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
675    for name in &["align", "align*", "aligned", "split"] {
676        map.insert(
677            name,
678            EnvSpec {
679                num_args: 0,
680                num_optional_args: 0,
681                handler: handle_aligned,
682            },
683        );
684    }
685}
686
687// ── gathered / gather / gather* ──────────────────────────────────────────
688
689fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
690    fn handle_gathered(
691        ctx: &mut EnvContext,
692        _args: Vec<ParseNode>,
693        _opt_args: Vec<Option<ParseNode>>,
694    ) -> ParseResult<ParseNode> {
695        let config = ArrayConfig {
696            cols: Some(vec![AlignSpec {
697                align_type: AlignType::Align,
698                align: Some("c".to_string()),
699                pregap: None,
700                postgap: None,
701            }]),
702            add_jot: Some(true),
703            col_separation_type: Some("gather".to_string()),
704            empty_single_row: true,
705            ..Default::default()
706        };
707        parse_array(ctx.parser, config, Some(StyleStr::Display))
708    }
709
710    for name in &["gathered", "gather", "gather*"] {
711        map.insert(
712            name,
713            EnvSpec {
714                num_args: 0,
715                num_optional_args: 0,
716                handler: handle_gathered,
717            },
718        );
719    }
720}
721
722// ── equation / equation* ─────────────────────────────────────────────────
723
724fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
725    fn handle_equation(
726        ctx: &mut EnvContext,
727        _args: Vec<ParseNode>,
728        _opt_args: Vec<Option<ParseNode>>,
729    ) -> ParseResult<ParseNode> {
730        let config = ArrayConfig {
731            empty_single_row: true,
732            single_row: true,
733            max_num_cols: Some(1),
734            ..Default::default()
735        };
736        parse_array(ctx.parser, config, Some(StyleStr::Display))
737    }
738
739    for name in &["equation", "equation*"] {
740        map.insert(
741            name,
742            EnvSpec {
743                num_args: 0,
744                num_optional_args: 0,
745                handler: handle_equation,
746            },
747        );
748    }
749}
750
751// ── smallmatrix ──────────────────────────────────────────────────────────
752
753fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
754    fn handle_smallmatrix(
755        ctx: &mut EnvContext,
756        _args: Vec<ParseNode>,
757        _opt_args: Vec<Option<ParseNode>>,
758    ) -> ParseResult<ParseNode> {
759        let config = ArrayConfig {
760            arraystretch: Some(0.5),
761            ..Default::default()
762        };
763        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
764        if let ParseNode::Array {
765            ref mut col_separation_type,
766            ..
767        } = res
768        {
769            *col_separation_type = Some("small".to_string());
770        }
771        Ok(res)
772    }
773
774    map.insert(
775        "smallmatrix",
776        EnvSpec {
777            num_args: 0,
778            num_optional_args: 0,
779            handler: handle_smallmatrix,
780        },
781    );
782}
783
784// ── alignat / alignat* / alignedat ──────────────────────────────────────
785
786fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
787    for name in &["alignat", "alignat*", "alignedat"] {
788        map.insert(
789            name,
790            EnvSpec {
791                num_args: 1,
792                num_optional_args: 0,
793                handler: handle_aligned,
794            },
795        );
796    }
797}
798
799// ── CD (amscd commutative diagrams) ──────────────────────────────────────
800
801fn register_cd(map: &mut HashMap<&'static str, EnvSpec>) {
802    fn handle_cd(
803        ctx: &mut EnvContext,
804        _args: Vec<ParseNode>,
805        _opt_args: Vec<Option<ParseNode>>,
806    ) -> ParseResult<ParseNode> {
807        // Collect all raw tokens until \end
808        let mut raw: Vec<Token> = Vec::new();
809        loop {
810            let tok = ctx.parser.gullet.future().clone();
811            if tok.text == "\\end" || tok.text == "EOF" {
812                break;
813            }
814            ctx.parser.gullet.pop_token();
815            raw.push(tok);
816        }
817
818        // Split into rows at \\ or \cr
819        let rows = cd_split_rows(raw);
820
821        let mut body: Vec<Vec<ParseNode>> = Vec::new();
822        let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
823        let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
824        hlines_before_row.push(vec![]);
825
826        for row_toks in rows {
827            // Skip purely-whitespace rows
828            if row_toks.iter().all(|t| t.text == " ") {
829                continue;
830            }
831            let cells = cd_parse_row(ctx.parser, row_toks)?;
832            if !cells.is_empty() {
833                body.push(cells);
834                row_gaps.push(None);
835                hlines_before_row.push(vec![]);
836            }
837        }
838
839        if body.is_empty() {
840            body.push(vec![]);
841            hlines_before_row.push(vec![]);
842        }
843
844        Ok(ParseNode::Array {
845            mode: ctx.mode,
846            body,
847            row_gaps,
848            hlines_before_row,
849            cols: None,
850            col_separation_type: Some("CD".to_string()),
851            hskip_before_and_after: Some(false),
852            add_jot: None,
853            arraystretch: 1.0,
854            tags: None,
855            leqno: None,
856            is_cd: Some(true),
857            loc: None,
858        })
859    }
860
861    map.insert(
862        "CD",
863        EnvSpec {
864            num_args: 0,
865            num_optional_args: 0,
866            handler: handle_cd,
867        },
868    );
869}
870
871/// Split a flat token list into rows at `\\` or `\cr` boundaries.
872fn cd_split_rows(tokens: Vec<Token>) -> Vec<Vec<Token>> {
873    let mut rows: Vec<Vec<Token>> = Vec::new();
874    let mut current: Vec<Token> = Vec::new();
875    for tok in tokens {
876        if tok.text == "\\\\" || tok.text == "\\cr" {
877            rows.push(current);
878            current = Vec::new();
879        } else {
880            current.push(tok);
881        }
882    }
883    if !current.is_empty() {
884        rows.push(current);
885    }
886    rows
887}
888
889/// Collect tokens from `tokens[start..]` up to (but not including) the first
890/// token whose text equals `delimiter`.  Returns (collected_tokens, tokens_consumed).
891/// `tokens_consumed` includes the delimiter itself if found.
892fn cd_collect_until(tokens: &[Token], start: usize, delimiter: &str) -> (Vec<Token>, usize) {
893    let mut result = Vec::new();
894    let mut i = start;
895    while i < tokens.len() {
896        if tokens[i].text == delimiter {
897            i += 1; // consume the delimiter
898            break;
899        }
900        result.push(tokens[i].clone());
901        i += 1;
902    }
903    (result, i - start)
904}
905
906/// Collect tokens from `tokens[start..]` up to (but not including) the next `@`.
907fn cd_collect_until_at(tokens: &[Token], start: usize) -> (Vec<Token>, usize) {
908    let mut result = Vec::new();
909    let mut i = start;
910    while i < tokens.len() && tokens[i].text != "@" {
911        result.push(tokens[i].clone());
912        i += 1;
913    }
914    (result, i - start)
915}
916
917/// Use the parser to parse a token slice as a math OrdGroup.
918/// Tokens must be in forward order; this function reverses them internally for subparse().
919fn cd_parse_tokens(parser: &mut Parser, tokens: Vec<Token>) -> ParseResult<ParseNode> {
920    // Filter pure whitespace
921    let has_content = tokens.iter().any(|t| t.text != " ");
922    if !has_content {
923        return Ok(ParseNode::OrdGroup {
924            mode: parser.mode,
925            body: vec![],
926            semisimple: None,
927            loc: None,
928        });
929    }
930    // subparse() expects tokens in reverse order (stack convention)
931    let mut rev = tokens;
932    rev.reverse();
933    let body = parser.subparse(rev)?;
934    Ok(ParseNode::OrdGroup {
935        mode: parser.mode,
936        body,
937        semisimple: None,
938        loc: None,
939    })
940}
941
942/// Parse one row of a CD environment from its raw token list.
943/// Returns the list of ParseNode cells for the grid row.
944fn cd_parse_row(parser: &mut Parser, row_tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
945    let toks = &row_tokens;
946    let n = toks.len();
947    let mut cells: Vec<ParseNode> = Vec::new();
948    let mut i = 0usize;
949
950    while i < n {
951        // Skip spaces at start of each cell
952        while i < n && toks[i].text == " " {
953            i += 1;
954        }
955        if i >= n {
956            break;
957        }
958
959        if toks[i].text == "@" {
960            i += 1; // consume `@`
961            if i >= n {
962                return Err(ParseError::msg("Unexpected end of CD row after @"));
963            }
964            let dir = toks[i].text.clone();
965            i += 1; // consume direction char
966
967            let mode = parser.mode;
968            let arrow = match dir.as_str() {
969                ">" | "<" => {
970                    let (above_toks, c1) = cd_collect_until(toks, i, &dir);
971                    i += c1;
972                    let (below_toks, c2) = cd_collect_until(toks, i, &dir);
973                    i += c2;
974                    let label_above = cd_parse_tokens(parser, above_toks)?;
975                    let label_below = cd_parse_tokens(parser, below_toks)?;
976                    ParseNode::CdArrow {
977                        mode,
978                        direction: if dir == ">" { "right" } else { "left" }.to_string(),
979                        label_above: Some(Box::new(label_above)),
980                        label_below: Some(Box::new(label_below)),
981                        loc: None,
982                    }
983                }
984                "V" | "A" => {
985                    let (left_toks, c1) = cd_collect_until(toks, i, &dir);
986                    i += c1;
987                    let (right_toks, c2) = cd_collect_until(toks, i, &dir);
988                    i += c2;
989                    let label_above = cd_parse_tokens(parser, left_toks)?;
990                    let label_below = cd_parse_tokens(parser, right_toks)?;
991                    ParseNode::CdArrow {
992                        mode,
993                        direction: if dir == "V" { "down" } else { "up" }.to_string(),
994                        label_above: Some(Box::new(label_above)),
995                        label_below: Some(Box::new(label_below)),
996                        loc: None,
997                    }
998                }
999                "=" => ParseNode::CdArrow {
1000                    mode,
1001                    direction: "horiz_eq".to_string(),
1002                    label_above: None,
1003                    label_below: None,
1004                    loc: None,
1005                },
1006                "|" => ParseNode::CdArrow {
1007                    mode,
1008                    direction: "vert_eq".to_string(),
1009                    label_above: None,
1010                    label_below: None,
1011                    loc: None,
1012                },
1013                "." => ParseNode::CdArrow {
1014                    mode,
1015                    direction: "none".to_string(),
1016                    label_above: None,
1017                    label_below: None,
1018                    loc: None,
1019                },
1020                _ => return Err(ParseError::msg(format!("Unknown CD directive: @{}", dir))),
1021            };
1022            cells.push(arrow);
1023        } else {
1024            // Object cell: collect until next `@`
1025            let (obj_toks, consumed) = cd_collect_until_at(toks, i);
1026            i += consumed;
1027            let obj = cd_parse_tokens(parser, obj_toks)?;
1028            cells.push(obj);
1029        }
1030    }
1031
1032    // Post-process: structure cells into the (2n-1) grid pattern.
1033    Ok(cd_structure_row(cells, parser.mode))
1034}
1035
1036/// Given the raw parsed cells of one CD row, produce the correctly-structured grid row.
1037///
1038/// Object rows already alternate: obj, h-arrow, obj, h-arrow, …, obj.
1039/// Arrow rows contain only CdArrow nodes (plus whitespace OrdGroups which we strip),
1040/// and need empty OrdGroup fillers inserted between consecutive arrows.
1041fn cd_structure_row(cells: Vec<ParseNode>, mode: Mode) -> Vec<ParseNode> {
1042    // Detect arrow row: all cells are either CdArrow or empty OrdGroup
1043    let is_arrow_row = cells.iter().all(|c| match c {
1044        ParseNode::CdArrow { .. } => true,
1045        ParseNode::OrdGroup { body, .. } => body.is_empty(),
1046        _ => false,
1047    }) && cells.iter().any(|c| matches!(c, ParseNode::CdArrow { .. }));
1048
1049    if is_arrow_row {
1050        let arrows: Vec<ParseNode> = cells
1051            .into_iter()
1052            .filter(|c| matches!(c, ParseNode::CdArrow { .. }))
1053            .collect();
1054
1055        if arrows.is_empty() {
1056            return vec![];
1057        }
1058
1059        let empty = || ParseNode::OrdGroup {
1060            mode,
1061            body: vec![],
1062            semisimple: None,
1063            loc: None,
1064        };
1065
1066        let mut result = Vec::with_capacity(arrows.len() * 2 - 1);
1067        for (idx, arrow) in arrows.into_iter().enumerate() {
1068            if idx > 0 {
1069                result.push(empty());
1070            }
1071            result.push(arrow);
1072        }
1073        result
1074    } else {
1075        // Object row: already in correct format
1076        cells
1077    }
1078}
1079
1080// ── subarray ────────────────────────────────────────────────────────────
1081
1082fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
1083    fn handle_subarray(
1084        ctx: &mut EnvContext,
1085        args: Vec<ParseNode>,
1086        _opt_args: Vec<Option<ParseNode>>,
1087    ) -> ParseResult<ParseNode> {
1088        let colalign = match &args[0] {
1089            ParseNode::OrdGroup { body, .. } => body.clone(),
1090            other if other.is_symbol_node() => vec![other.clone()],
1091            _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
1092        };
1093
1094        let mut cols = Vec::new();
1095        for nde in &colalign {
1096            let ca = nde
1097                .symbol_text()
1098                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
1099            match ca {
1100                "l" | "c" => cols.push(AlignSpec {
1101                    align_type: AlignType::Align,
1102                    align: Some(ca.to_string()),
1103                    pregap: None,
1104                    postgap: None,
1105                }),
1106                _ => {
1107                    return Err(ParseError::msg(format!(
1108                        "Unknown column alignment: {}",
1109                        ca
1110                    )))
1111                }
1112            }
1113        }
1114
1115        if cols.len() > 1 {
1116            return Err(ParseError::msg("{subarray} can contain only one column"));
1117        }
1118
1119        let config = ArrayConfig {
1120            cols: Some(cols),
1121            hskip_before_and_after: Some(false),
1122            arraystretch: Some(0.5),
1123            ..Default::default()
1124        };
1125
1126        let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
1127
1128        if let ParseNode::Array { ref body, .. } = res {
1129            if !body.is_empty() && body[0].len() > 1 {
1130                return Err(ParseError::msg("{subarray} can contain only one column"));
1131            }
1132        }
1133
1134        Ok(res)
1135    }
1136
1137    map.insert(
1138        "subarray",
1139        EnvSpec {
1140            num_args: 1,
1141            num_optional_args: 0,
1142            handler: handle_subarray,
1143        },
1144    );
1145}