Skip to main content

ratex_parser/
environments.rs

1use std::collections::HashMap;
2
3use ratex_lexer::token::Token;
4
5use crate::error::{ParseError, ParseResult};
6use crate::macro_expander::MacroDefinition;
7use crate::parse_node::{AlignSpec, AlignType, Measurement, Mode, ParseNode, StyleStr};
8use crate::parser::Parser;
9
10// ── Environment registry ─────────────────────────────────────────────────
11
/// State handed to an environment handler for one `\begin{…}` … `\end{…}` span.
pub struct EnvContext<'a, 'b> {
    /// Mode that handlers stamp onto the wrapper nodes they build
    /// (e.g. the `LeftRight` around a delimited matrix).
    pub mode: Mode,
    /// Name of the environment being parsed, including any trailing `*`.
    /// Handlers shared by several environments branch on it.
    pub env_name: String,
    /// Parser the handler drives to consume the environment body.
    pub parser: &'a mut Parser<'b>,
}
17
/// Signature shared by every environment handler.
///
/// `args` holds the mandatory arguments and `opt_args` the optional ones.
/// Presumably their lengths match `EnvSpec::num_args` and
/// `EnvSpec::num_optional_args` — the caller is not in this file section, so
/// confirm before relying on it.
pub type EnvHandler = fn(
    ctx: &mut EnvContext,
    args: Vec<ParseNode>,
    opt_args: Vec<Option<ParseNode>>,
) -> ParseResult<ParseNode>;
23
/// Registry entry describing one environment.
pub struct EnvSpec {
    /// Number of mandatory arguments the environment takes
    /// (e.g. 1 for `array`, 0 for `matrix`).
    pub num_args: usize,
    /// Number of optional arguments the environment takes.
    pub num_optional_args: usize,
    /// Function that parses the environment body.
    pub handler: EnvHandler,
}
29
30pub static ENVIRONMENTS: std::sync::LazyLock<HashMap<&'static str, EnvSpec>> =
31    std::sync::LazyLock::new(|| {
32        let mut map = HashMap::new();
33        register_array(&mut map);
34        register_matrix(&mut map);
35        register_cases(&mut map);
36        register_align(&mut map);
37        register_gathered(&mut map);
38        register_equation(&mut map);
39        register_smallmatrix(&mut map);
40        register_alignat(&mut map);
41        register_subarray(&mut map);
42        register_cd(&mut map);
43        map
44    });
45
46// ── ArrayConfig ──────────────────────────────────────────────────────────
47
/// Options controlling how `parse_array` reads an environment body.
///
/// All fields default to `None` / `false`, so handlers only set what they need
/// and fill the rest with `..Default::default()`.
#[derive(Default)]
pub struct ArrayConfig {
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub hskip_before_and_after: Option<bool>,
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub add_jot: Option<bool>,
    /// Column specifications; copied unchanged into the resulting array node.
    pub cols: Option<Vec<AlignSpec>>,
    /// Row stretch factor. When `None`, `parse_array` falls back to the
    /// `\arraystretch` macro if one is defined, otherwise to `1.0`.
    pub arraystretch: Option<f64>,
    /// Copied unchanged into the resulting `ParseNode::Array`
    /// (values used in this file: "align", "alignat", "gather", "small", "CD").
    pub col_separation_type: Option<String>,
    /// When true, cells are terminated by `\end` instead of `\\`, and `\cr`
    /// is not defined as a row break.
    pub single_row: bool,
    /// When true, a body consisting of one empty row is kept rather than dropped.
    pub empty_single_row: bool,
    /// Maximum number of cells per row; a `&` after that many cells is an error.
    pub max_num_cols: Option<usize>,
    /// Copied unchanged into the resulting `ParseNode::Array`.
    pub leqno: Option<bool>,
}
60
61
62// ── parseArray ───────────────────────────────────────────────────────────
63
64fn get_hlines(parser: &mut Parser) -> ParseResult<Vec<bool>> {
65    let mut hline_info = Vec::new();
66    parser.consume_spaces()?;
67
68    let mut nxt = parser.fetch()?.text.clone();
69    if nxt == "\\relax" {
70        parser.consume();
71        parser.consume_spaces()?;
72        nxt = parser.fetch()?.text.clone();
73    }
74    while nxt == "\\hline" || nxt == "\\hdashline" {
75        parser.consume();
76        hline_info.push(nxt == "\\hdashline");
77        parser.consume_spaces()?;
78        nxt = parser.fetch()?.text.clone();
79    }
80    Ok(hline_info)
81}
82
83fn d_cell_style(env_name: &str) -> Option<StyleStr> {
84    if env_name.starts_with('d') {
85        Some(StyleStr::Display)
86    } else {
87        Some(StyleStr::Text)
88    }
89}
90
91pub fn parse_array(
92    parser: &mut Parser,
93    config: ArrayConfig,
94    style: Option<StyleStr>,
95) -> ParseResult<ParseNode> {
96    parser.gullet.begin_group();
97
98    if !config.single_row {
99        parser
100            .gullet
101            .set_text_macro("\\cr", "\\\\\\relax");
102    }
103
104    let arraystretch = config.arraystretch.unwrap_or_else(|| {
105        // Check if \arraystretch is defined as a macro (e.g., via \def\arraystretch{1.5})
106        if let Some(def) = parser.gullet.get_macro("\\arraystretch") {
107            let s = match def {
108                MacroDefinition::Text(s) => s.clone(),
109                MacroDefinition::Tokens { tokens, .. } => {
110                    // Tokens are stored in reverse order (stack convention for expansion)
111                    tokens.iter().rev().map(|t| t.text.as_str()).collect::<String>()
112                }
113                MacroDefinition::Function(_) => String::new(),
114            };
115            s.parse::<f64>().unwrap_or(1.0)
116        } else {
117            1.0
118        }
119    });
120
121    parser.gullet.begin_group();
122
123    let mut row: Vec<ParseNode> = Vec::new();
124    let mut body: Vec<Vec<ParseNode>> = Vec::new();
125    let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
126    let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
127
128    hlines_before_row.push(get_hlines(parser)?);
129
130    loop {
131        let break_token = if config.single_row { "\\end" } else { "\\\\" };
132        let cell_body = parser.parse_expression(false, Some(break_token))?;
133        parser.gullet.end_group();
134        parser.gullet.begin_group();
135
136        let mut cell = ParseNode::OrdGroup {
137            mode: parser.mode,
138            body: cell_body,
139            semisimple: None,
140            loc: None,
141        };
142
143        if let Some(s) = style {
144            cell = ParseNode::Styling {
145                mode: parser.mode,
146                style: s,
147                body: vec![cell],
148                loc: None,
149            };
150        }
151
152        row.push(cell.clone());
153        let next = parser.fetch()?.text.clone();
154
155        if next == "&" {
156            if let Some(max) = config.max_num_cols {
157                if row.len() >= max {
158                    return Err(ParseError::msg("Too many tab characters: &"));
159                }
160            }
161            parser.consume();
162        } else if next == "\\end" {
163            // Check for trailing empty row and remove it
164            let is_empty_trailing = if let Some(s) = style {
165                if s == StyleStr::Text || s == StyleStr::Display {
166                    if let ParseNode::Styling { body: ref sb, .. } = cell {
167                        if let Some(ParseNode::OrdGroup {
168                            body: ref ob, ..
169                        }) = sb.first()
170                        {
171                            ob.is_empty()
172                        } else {
173                            false
174                        }
175                    } else {
176                        false
177                    }
178                } else {
179                    false
180                }
181            } else if let ParseNode::OrdGroup { body: ref ob, .. } = cell {
182                ob.is_empty()
183            } else {
184                false
185            };
186
187            body.push(row);
188
189            if is_empty_trailing
190                && (body.len() > 1 || !config.empty_single_row)
191            {
192                body.pop();
193            }
194
195            if hlines_before_row.len() < body.len() + 1 {
196                hlines_before_row.push(vec![]);
197            }
198            break;
199        } else if next == "\\\\" {
200            parser.consume();
201            let size = if parser.gullet.future().text != " " {
202                parser.parse_size_group(true)?
203            } else {
204                None
205            };
206            let gap = size.and_then(|s| {
207                if let ParseNode::Size { value, .. } = s {
208                    Some(value)
209                } else {
210                    None
211                }
212            });
213            row_gaps.push(gap);
214
215            body.push(row);
216            hlines_before_row.push(get_hlines(parser)?);
217            row = Vec::new();
218        } else {
219            return Err(ParseError::msg(format!(
220                "Expected & or \\\\ or \\cr or \\end, got '{}'",
221                next
222            )));
223        }
224    }
225
226    parser.gullet.end_group();
227    parser.gullet.end_group();
228
229    Ok(ParseNode::Array {
230        mode: parser.mode,
231        body,
232        row_gaps,
233        hlines_before_row,
234        cols: config.cols,
235        col_separation_type: config.col_separation_type,
236        hskip_before_and_after: config.hskip_before_and_after,
237        add_jot: config.add_jot,
238        arraystretch,
239        tags: None,
240        leqno: config.leqno,
241        is_cd: None,
242        loc: None,
243    })
244}
245
246// ── array / darray ───────────────────────────────────────────────────────
247
248fn register_array(map: &mut HashMap<&'static str, EnvSpec>) {
249    fn handle_array(
250        ctx: &mut EnvContext,
251        args: Vec<ParseNode>,
252        _opt_args: Vec<Option<ParseNode>>,
253    ) -> ParseResult<ParseNode> {
254        let colalign = match &args[0] {
255            ParseNode::OrdGroup { body, .. } => body.clone(),
256            other if other.is_symbol_node() => vec![other.clone()],
257            _ => return Err(ParseError::msg("Invalid column alignment for array")),
258        };
259
260        let mut cols = Vec::new();
261        for nde in &colalign {
262            let ca = nde
263                .symbol_text()
264                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
265            match ca {
266                "l" | "c" | "r" => cols.push(AlignSpec {
267                    align_type: AlignType::Align,
268                    align: Some(ca.to_string()),
269                    pregap: None,
270                    postgap: None,
271                }),
272                "|" => cols.push(AlignSpec {
273                    align_type: AlignType::Separator,
274                    align: Some("|".to_string()),
275                    pregap: None,
276                    postgap: None,
277                }),
278                ":" => cols.push(AlignSpec {
279                    align_type: AlignType::Separator,
280                    align: Some(":".to_string()),
281                    pregap: None,
282                    postgap: None,
283                }),
284                _ => {
285                    return Err(ParseError::msg(format!(
286                        "Unknown column alignment: {}",
287                        ca
288                    )))
289                }
290            }
291        }
292
293        let max_num_cols = cols.len();
294        let config = ArrayConfig {
295            cols: Some(cols),
296            hskip_before_and_after: Some(true),
297            max_num_cols: Some(max_num_cols),
298            ..Default::default()
299        };
300        parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))
301    }
302
303    for name in &["array", "darray"] {
304        map.insert(
305            name,
306            EnvSpec {
307                num_args: 1,
308                num_optional_args: 0,
309                handler: handle_array,
310            },
311        );
312    }
313}
314
315// ── matrix variants ──────────────────────────────────────────────────────
316
317fn register_matrix(map: &mut HashMap<&'static str, EnvSpec>) {
318    fn handle_matrix(
319        ctx: &mut EnvContext,
320        _args: Vec<ParseNode>,
321        _opt_args: Vec<Option<ParseNode>>,
322    ) -> ParseResult<ParseNode> {
323        let base_name = ctx.env_name.replace('*', "");
324        let delimiters: Option<(&str, &str)> = match base_name.as_str() {
325            "matrix" => None,
326            "pmatrix" => Some(("(", ")")),
327            "bmatrix" => Some(("[", "]")),
328            "Bmatrix" => Some(("\\{", "\\}")),
329            "vmatrix" => Some(("|", "|")),
330            "Vmatrix" => Some(("\\Vert", "\\Vert")),
331            _ => None,
332        };
333
334        let mut col_align = "c".to_string();
335
336        // mathtools starred matrix: parse optional [l|c|r] alignment
337        if ctx.env_name.ends_with('*') {
338            ctx.parser.gullet.consume_spaces();
339            if ctx.parser.gullet.future().text == "[" {
340                ctx.parser.gullet.pop_token();
341                ctx.parser.gullet.consume_spaces();
342                let align_tok = ctx.parser.gullet.pop_token();
343                if !"lcr".contains(align_tok.text.as_str()) {
344                    return Err(ParseError::new(
345                        "Expected l or c or r".to_string(),
346                        Some(&align_tok),
347                    ));
348                }
349                col_align = align_tok.text.clone();
350                ctx.parser.gullet.consume_spaces();
351                let close = ctx.parser.gullet.pop_token();
352                if close.text != "]" {
353                    return Err(ParseError::new(
354                        "Expected ]".to_string(),
355                        Some(&close),
356                    ));
357                }
358            }
359        }
360
361        let config = ArrayConfig {
362            hskip_before_and_after: Some(false),
363            cols: Some(vec![AlignSpec {
364                align_type: AlignType::Align,
365                align: Some(col_align.clone()),
366                pregap: None,
367                postgap: None,
368            }]),
369            ..Default::default()
370        };
371
372        let mut res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
373
374        // Fix cols to match actual number of columns
375        if let ParseNode::Array {
376            ref body,
377            ref mut cols,
378            ..
379        } = res
380        {
381            let num_cols = body.iter().map(|r| r.len()).max().unwrap_or(0);
382            *cols = Some(
383                (0..num_cols)
384                    .map(|_| AlignSpec {
385                        align_type: AlignType::Align,
386                        align: Some(col_align.to_string()),
387                        pregap: None,
388                        postgap: None,
389                    })
390                    .collect(),
391            );
392        }
393
394        match delimiters {
395            Some((left, right)) => Ok(ParseNode::LeftRight {
396                mode: ctx.mode,
397                body: vec![res],
398                left: left.to_string(),
399                right: right.to_string(),
400                right_color: None,
401                loc: None,
402            }),
403            None => Ok(res),
404        }
405    }
406
407    for name in &[
408        "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix",
409        "matrix*", "pmatrix*", "bmatrix*", "Bmatrix*", "vmatrix*", "Vmatrix*",
410    ] {
411        map.insert(
412            name,
413            EnvSpec {
414                num_args: 0,
415                num_optional_args: 0,
416                handler: handle_matrix,
417            },
418        );
419    }
420}
421
422// ── cases / dcases / rcases / drcases ────────────────────────────────────
423
424fn register_cases(map: &mut HashMap<&'static str, EnvSpec>) {
425    fn handle_cases(
426        ctx: &mut EnvContext,
427        _args: Vec<ParseNode>,
428        _opt_args: Vec<Option<ParseNode>>,
429    ) -> ParseResult<ParseNode> {
430        let config = ArrayConfig {
431            arraystretch: Some(1.2),
432            cols: Some(vec![
433                AlignSpec {
434                    align_type: AlignType::Align,
435                    align: Some("l".to_string()),
436                    pregap: Some(0.0),
437                    postgap: Some(1.0),
438                },
439                AlignSpec {
440                    align_type: AlignType::Align,
441                    align: Some("l".to_string()),
442                    pregap: Some(0.0),
443                    postgap: Some(0.0),
444                },
445            ]),
446            ..Default::default()
447        };
448
449        let res = parse_array(ctx.parser, config, d_cell_style(&ctx.env_name))?;
450
451        let (left, right) = if ctx.env_name.contains('r') {
452            (".", "\\}")
453        } else {
454            ("\\{", ".")
455        };
456
457        Ok(ParseNode::LeftRight {
458            mode: ctx.mode,
459            body: vec![res],
460            left: left.to_string(),
461            right: right.to_string(),
462            right_color: None,
463            loc: None,
464        })
465    }
466
467    for name in &["cases", "dcases", "rcases", "drcases"] {
468        map.insert(
469            name,
470            EnvSpec {
471                num_args: 0,
472                num_optional_args: 0,
473                handler: handle_cases,
474            },
475        );
476    }
477}
478
479// ── align / align* / aligned / split / alignat / alignat* / alignedat ────
480
481fn handle_aligned(
482    ctx: &mut EnvContext,
483    args: Vec<ParseNode>,
484    _opt_args: Vec<Option<ParseNode>>,
485) -> ParseResult<ParseNode> {
486        let is_split = ctx.env_name == "split";
487        let is_alignat = ctx.env_name.contains("at");
488        let sep_type = if is_alignat { "alignat" } else { "align" };
489
490        let config = ArrayConfig {
491            add_jot: Some(true),
492            empty_single_row: true,
493            col_separation_type: Some(sep_type.to_string()),
494            max_num_cols: if is_split { Some(2) } else { None },
495            ..Default::default()
496        };
497
498        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Display))?;
499
500        // Extract explicit column count from first arg (alignat only)
501        let mut num_maths = 0usize;
502        let mut explicit_cols = 0usize;
503        if let Some(ParseNode::OrdGroup { body, .. }) = args.first() {
504            let mut arg_str = String::new();
505            for node in body {
506                if let Some(t) = node.symbol_text() {
507                    arg_str.push_str(t);
508                }
509            }
510            if let Ok(n) = arg_str.parse::<usize>() {
511                num_maths = n;
512                explicit_cols = n * 2;
513            }
514        }
515        let is_aligned = explicit_cols == 0;
516
517        // Determine actual number of columns
518        let mut num_cols = if let ParseNode::Array { ref body, .. } = res {
519            body.iter().map(|r| r.len()).max().unwrap_or(0)
520        } else {
521            0
522        };
523
524        if let ParseNode::Array {
525            body: ref mut array_body,
526            ..
527        } = res
528        {
529            for row in array_body.iter_mut() {
530                // Prepend empty group at every even-indexed cell (2nd, 4th, ...)
531                let mut i = 1;
532                while i < row.len() {
533                    if let ParseNode::Styling {
534                        body: ref mut styling_body,
535                        ..
536                    } = row[i]
537                    {
538                        if let Some(ParseNode::OrdGroup {
539                            body: ref mut og_body,
540                            ..
541                        }) = styling_body.first_mut()
542                        {
543                            og_body.insert(
544                                0,
545                                ParseNode::OrdGroup {
546                                    mode: ctx.mode,
547                                    body: vec![],
548                                    semisimple: None,
549                                    loc: None,
550                                },
551                            );
552                        }
553                    }
554                    i += 2;
555                }
556
557                if !is_aligned {
558                    let cur_maths = row.len() / 2;
559                    if num_maths < cur_maths {
560                        return Err(ParseError::msg(format!(
561                            "Too many math in a row: expected {}, but got {}",
562                            num_maths, cur_maths
563                        )));
564                    }
565                } else if num_cols < row.len() {
566                    num_cols = row.len();
567                }
568            }
569        }
570
571        if !is_aligned {
572            num_cols = explicit_cols;
573        }
574
575        let mut cols = Vec::new();
576        for i in 0..num_cols {
577            let (align, pregap) = if i % 2 == 1 {
578                ("l", 0.0)
579            } else if i > 0 && is_aligned {
580                ("r", 1.0)
581            } else {
582                ("r", 0.0)
583            };
584            cols.push(AlignSpec {
585                align_type: AlignType::Align,
586                align: Some(align.to_string()),
587                pregap: Some(pregap),
588                postgap: Some(0.0),
589            });
590        }
591
592        if let ParseNode::Array {
593            cols: ref mut array_cols,
594            col_separation_type: ref mut array_sep_type,
595            ..
596        } = res
597        {
598            *array_cols = Some(cols);
599            *array_sep_type = Some(
600                if is_aligned { "align" } else { "alignat" }.to_string(),
601            );
602        }
603
604    Ok(res)
605}
606
607fn register_align(map: &mut HashMap<&'static str, EnvSpec>) {
608    for name in &["align", "align*", "aligned", "split"] {
609        map.insert(
610            name,
611            EnvSpec {
612                num_args: 0,
613                num_optional_args: 0,
614                handler: handle_aligned,
615            },
616        );
617    }
618}
619
620// ── gathered / gather / gather* ──────────────────────────────────────────
621
622fn register_gathered(map: &mut HashMap<&'static str, EnvSpec>) {
623    fn handle_gathered(
624        ctx: &mut EnvContext,
625        _args: Vec<ParseNode>,
626        _opt_args: Vec<Option<ParseNode>>,
627    ) -> ParseResult<ParseNode> {
628        let config = ArrayConfig {
629            cols: Some(vec![AlignSpec {
630                align_type: AlignType::Align,
631                align: Some("c".to_string()),
632                pregap: None,
633                postgap: None,
634            }]),
635            add_jot: Some(true),
636            col_separation_type: Some("gather".to_string()),
637            empty_single_row: true,
638            ..Default::default()
639        };
640        parse_array(ctx.parser, config, Some(StyleStr::Display))
641    }
642
643    for name in &["gathered", "gather", "gather*"] {
644        map.insert(
645            name,
646            EnvSpec {
647                num_args: 0,
648                num_optional_args: 0,
649                handler: handle_gathered,
650            },
651        );
652    }
653}
654
655// ── equation / equation* ─────────────────────────────────────────────────
656
657fn register_equation(map: &mut HashMap<&'static str, EnvSpec>) {
658    fn handle_equation(
659        ctx: &mut EnvContext,
660        _args: Vec<ParseNode>,
661        _opt_args: Vec<Option<ParseNode>>,
662    ) -> ParseResult<ParseNode> {
663        let config = ArrayConfig {
664            empty_single_row: true,
665            single_row: true,
666            max_num_cols: Some(1),
667            ..Default::default()
668        };
669        parse_array(ctx.parser, config, Some(StyleStr::Display))
670    }
671
672    for name in &["equation", "equation*"] {
673        map.insert(
674            name,
675            EnvSpec {
676                num_args: 0,
677                num_optional_args: 0,
678                handler: handle_equation,
679            },
680        );
681    }
682}
683
684// ── smallmatrix ──────────────────────────────────────────────────────────
685
686fn register_smallmatrix(map: &mut HashMap<&'static str, EnvSpec>) {
687    fn handle_smallmatrix(
688        ctx: &mut EnvContext,
689        _args: Vec<ParseNode>,
690        _opt_args: Vec<Option<ParseNode>>,
691    ) -> ParseResult<ParseNode> {
692        let config = ArrayConfig {
693            arraystretch: Some(0.5),
694            ..Default::default()
695        };
696        let mut res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
697        if let ParseNode::Array {
698            ref mut col_separation_type,
699            ..
700        } = res
701        {
702            *col_separation_type = Some("small".to_string());
703        }
704        Ok(res)
705    }
706
707    map.insert(
708        "smallmatrix",
709        EnvSpec {
710            num_args: 0,
711            num_optional_args: 0,
712            handler: handle_smallmatrix,
713        },
714    );
715}
716
717// ── alignat / alignat* / alignedat ──────────────────────────────────────
718
719fn register_alignat(map: &mut HashMap<&'static str, EnvSpec>) {
720    for name in &["alignat", "alignat*", "alignedat"] {
721        map.insert(
722            name,
723            EnvSpec {
724                num_args: 1,
725                num_optional_args: 0,
726                handler: handle_aligned,
727            },
728        );
729    }
730}
731
732// ── CD (amscd commutative diagrams) ──────────────────────────────────────
733
734fn register_cd(map: &mut HashMap<&'static str, EnvSpec>) {
735    fn handle_cd(
736        ctx: &mut EnvContext,
737        _args: Vec<ParseNode>,
738        _opt_args: Vec<Option<ParseNode>>,
739    ) -> ParseResult<ParseNode> {
740        // Collect all raw tokens until \end
741        let mut raw: Vec<Token> = Vec::new();
742        loop {
743            let tok = ctx.parser.gullet.future().clone();
744            if tok.text == "\\end" || tok.text == "EOF" {
745                break;
746            }
747            ctx.parser.gullet.pop_token();
748            raw.push(tok);
749        }
750
751        // Split into rows at \\ or \cr
752        let rows = cd_split_rows(raw);
753
754        let mut body: Vec<Vec<ParseNode>> = Vec::new();
755        let mut row_gaps: Vec<Option<Measurement>> = Vec::new();
756        let mut hlines_before_row: Vec<Vec<bool>> = Vec::new();
757        hlines_before_row.push(vec![]);
758
759        for row_toks in rows {
760            // Skip purely-whitespace rows
761            if row_toks.iter().all(|t| t.text == " ") {
762                continue;
763            }
764            let cells = cd_parse_row(ctx.parser, row_toks)?;
765            if !cells.is_empty() {
766                body.push(cells);
767                row_gaps.push(None);
768                hlines_before_row.push(vec![]);
769            }
770        }
771
772        if body.is_empty() {
773            body.push(vec![]);
774            hlines_before_row.push(vec![]);
775        }
776
777        Ok(ParseNode::Array {
778            mode: ctx.mode,
779            body,
780            row_gaps,
781            hlines_before_row,
782            cols: None,
783            col_separation_type: Some("CD".to_string()),
784            hskip_before_and_after: Some(false),
785            add_jot: None,
786            arraystretch: 1.0,
787            tags: None,
788            leqno: None,
789            is_cd: Some(true),
790            loc: None,
791        })
792    }
793
794    map.insert(
795        "CD",
796        EnvSpec {
797            num_args: 0,
798            num_optional_args: 0,
799            handler: handle_cd,
800        },
801    );
802}
803
804/// Split a flat token list into rows at `\\` or `\cr` boundaries.
805fn cd_split_rows(tokens: Vec<Token>) -> Vec<Vec<Token>> {
806    let mut rows: Vec<Vec<Token>> = Vec::new();
807    let mut current: Vec<Token> = Vec::new();
808    for tok in tokens {
809        if tok.text == "\\\\" || tok.text == "\\cr" {
810            rows.push(current);
811            current = Vec::new();
812        } else {
813            current.push(tok);
814        }
815    }
816    if !current.is_empty() {
817        rows.push(current);
818    }
819    rows
820}
821
822/// Collect tokens from `tokens[start..]` up to (but not including) the first
823/// token whose text equals `delimiter`.  Returns (collected_tokens, tokens_consumed).
824/// `tokens_consumed` includes the delimiter itself if found.
825fn cd_collect_until(tokens: &[Token], start: usize, delimiter: &str) -> (Vec<Token>, usize) {
826    let mut result = Vec::new();
827    let mut i = start;
828    while i < tokens.len() {
829        if tokens[i].text == delimiter {
830            i += 1; // consume the delimiter
831            break;
832        }
833        result.push(tokens[i].clone());
834        i += 1;
835    }
836    (result, i - start)
837}
838
839/// Collect tokens from `tokens[start..]` up to (but not including) the next `@`.
840fn cd_collect_until_at(tokens: &[Token], start: usize) -> (Vec<Token>, usize) {
841    let mut result = Vec::new();
842    let mut i = start;
843    while i < tokens.len() && tokens[i].text != "@" {
844        result.push(tokens[i].clone());
845        i += 1;
846    }
847    (result, i - start)
848}
849
850/// Use the parser to parse a token slice as a math OrdGroup.
851/// Tokens must be in forward order; this function reverses them internally for subparse().
852fn cd_parse_tokens(parser: &mut Parser, tokens: Vec<Token>) -> ParseResult<ParseNode> {
853    // Filter pure whitespace
854    let has_content = tokens.iter().any(|t| t.text != " ");
855    if !has_content {
856        return Ok(ParseNode::OrdGroup {
857            mode: parser.mode,
858            body: vec![],
859            semisimple: None,
860            loc: None,
861        });
862    }
863    // subparse() expects tokens in reverse order (stack convention)
864    let mut rev = tokens;
865    rev.reverse();
866    let body = parser.subparse(rev)?;
867    Ok(ParseNode::OrdGroup {
868        mode: parser.mode,
869        body,
870        semisimple: None,
871        loc: None,
872    })
873}
874
875/// Parse one row of a CD environment from its raw token list.
876/// Returns the list of ParseNode cells for the grid row.
877fn cd_parse_row(parser: &mut Parser, row_tokens: Vec<Token>) -> ParseResult<Vec<ParseNode>> {
878    let toks = &row_tokens;
879    let n = toks.len();
880    let mut cells: Vec<ParseNode> = Vec::new();
881    let mut i = 0usize;
882
883    while i < n {
884        // Skip spaces at start of each cell
885        while i < n && toks[i].text == " " {
886            i += 1;
887        }
888        if i >= n {
889            break;
890        }
891
892        if toks[i].text == "@" {
893            i += 1; // consume `@`
894            if i >= n {
895                return Err(ParseError::msg("Unexpected end of CD row after @"));
896            }
897            let dir = toks[i].text.clone();
898            i += 1; // consume direction char
899
900            let mode = parser.mode;
901            let arrow = match dir.as_str() {
902                ">" | "<" => {
903                    let (above_toks, c1) = cd_collect_until(toks, i, &dir);
904                    i += c1;
905                    let (below_toks, c2) = cd_collect_until(toks, i, &dir);
906                    i += c2;
907                    let label_above = cd_parse_tokens(parser, above_toks)?;
908                    let label_below = cd_parse_tokens(parser, below_toks)?;
909                    ParseNode::CdArrow {
910                        mode,
911                        direction: if dir == ">" { "right" } else { "left" }.to_string(),
912                        label_above: Some(Box::new(label_above)),
913                        label_below: Some(Box::new(label_below)),
914                        loc: None,
915                    }
916                }
917                "V" | "A" => {
918                    let (left_toks, c1) = cd_collect_until(toks, i, &dir);
919                    i += c1;
920                    let (right_toks, c2) = cd_collect_until(toks, i, &dir);
921                    i += c2;
922                    let label_above = cd_parse_tokens(parser, left_toks)?;
923                    let label_below = cd_parse_tokens(parser, right_toks)?;
924                    ParseNode::CdArrow {
925                        mode,
926                        direction: if dir == "V" { "down" } else { "up" }.to_string(),
927                        label_above: Some(Box::new(label_above)),
928                        label_below: Some(Box::new(label_below)),
929                        loc: None,
930                    }
931                }
932                "=" => ParseNode::CdArrow {
933                    mode,
934                    direction: "horiz_eq".to_string(),
935                    label_above: None,
936                    label_below: None,
937                    loc: None,
938                },
939                "|" => ParseNode::CdArrow {
940                    mode,
941                    direction: "vert_eq".to_string(),
942                    label_above: None,
943                    label_below: None,
944                    loc: None,
945                },
946                "." => ParseNode::CdArrow {
947                    mode,
948                    direction: "none".to_string(),
949                    label_above: None,
950                    label_below: None,
951                    loc: None,
952                },
953                _ => return Err(ParseError::msg(format!("Unknown CD directive: @{}", dir))),
954            };
955            cells.push(arrow);
956        } else {
957            // Object cell: collect until next `@`
958            let (obj_toks, consumed) = cd_collect_until_at(toks, i);
959            i += consumed;
960            let obj = cd_parse_tokens(parser, obj_toks)?;
961            cells.push(obj);
962        }
963    }
964
965    // Post-process: structure cells into the (2n-1) grid pattern.
966    Ok(cd_structure_row(cells, parser.mode))
967}
968
969/// Given the raw parsed cells of one CD row, produce the correctly-structured grid row.
970///
971/// Object rows already alternate: obj, h-arrow, obj, h-arrow, …, obj.
972/// Arrow rows contain only CdArrow nodes (plus whitespace OrdGroups which we strip),
973/// and need empty OrdGroup fillers inserted between consecutive arrows.
974fn cd_structure_row(cells: Vec<ParseNode>, mode: Mode) -> Vec<ParseNode> {
975    // Detect arrow row: all cells are either CdArrow or empty OrdGroup
976    let is_arrow_row = cells.iter().all(|c| match c {
977        ParseNode::CdArrow { .. } => true,
978        ParseNode::OrdGroup { body, .. } => body.is_empty(),
979        _ => false,
980    }) && cells.iter().any(|c| matches!(c, ParseNode::CdArrow { .. }));
981
982    if is_arrow_row {
983        let arrows: Vec<ParseNode> = cells
984            .into_iter()
985            .filter(|c| matches!(c, ParseNode::CdArrow { .. }))
986            .collect();
987
988        if arrows.is_empty() {
989            return vec![];
990        }
991
992        let empty = || ParseNode::OrdGroup {
993            mode,
994            body: vec![],
995            semisimple: None,
996            loc: None,
997        };
998
999        let mut result = Vec::with_capacity(arrows.len() * 2 - 1);
1000        for (idx, arrow) in arrows.into_iter().enumerate() {
1001            if idx > 0 {
1002                result.push(empty());
1003            }
1004            result.push(arrow);
1005        }
1006        result
1007    } else {
1008        // Object row: already in correct format
1009        cells
1010    }
1011}
1012
1013// ── subarray ────────────────────────────────────────────────────────────
1014
1015fn register_subarray(map: &mut HashMap<&'static str, EnvSpec>) {
1016    fn handle_subarray(
1017        ctx: &mut EnvContext,
1018        args: Vec<ParseNode>,
1019        _opt_args: Vec<Option<ParseNode>>,
1020    ) -> ParseResult<ParseNode> {
1021        let colalign = match &args[0] {
1022            ParseNode::OrdGroup { body, .. } => body.clone(),
1023            other if other.is_symbol_node() => vec![other.clone()],
1024            _ => return Err(ParseError::msg("Invalid column alignment for subarray")),
1025        };
1026
1027        let mut cols = Vec::new();
1028        for nde in &colalign {
1029            let ca = nde
1030                .symbol_text()
1031                .ok_or_else(|| ParseError::msg("Expected column alignment character"))?;
1032            match ca {
1033                "l" | "c" => cols.push(AlignSpec {
1034                    align_type: AlignType::Align,
1035                    align: Some(ca.to_string()),
1036                    pregap: None,
1037                    postgap: None,
1038                }),
1039                _ => {
1040                    return Err(ParseError::msg(format!(
1041                        "Unknown column alignment: {}",
1042                        ca
1043                    )))
1044                }
1045            }
1046        }
1047
1048        if cols.len() > 1 {
1049            return Err(ParseError::msg("{subarray} can contain only one column"));
1050        }
1051
1052        let config = ArrayConfig {
1053            cols: Some(cols),
1054            hskip_before_and_after: Some(false),
1055            arraystretch: Some(0.5),
1056            ..Default::default()
1057        };
1058
1059        let res = parse_array(ctx.parser, config, Some(StyleStr::Script))?;
1060
1061        if let ParseNode::Array { ref body, .. } = res {
1062            if !body.is_empty() && body[0].len() > 1 {
1063                return Err(ParseError::msg("{subarray} can contain only one column"));
1064            }
1065        }
1066
1067        Ok(res)
1068    }
1069
1070    map.insert(
1071        "subarray",
1072        EnvSpec {
1073            num_args: 1,
1074            num_optional_args: 0,
1075            handler: handle_subarray,
1076        },
1077    );
1078}